diff --git a/docs/mindspore/source_en/design/distributed_training_design.md b/docs/mindspore/source_en/design/distributed_training_design.md
index 302cfc6acf6bf80ab4ced696e8c3b7d87ed27e93..0cee1d6ae1700a42043bd6910e65533a87dd9e6a 100644
--- a/docs/mindspore/source_en/design/distributed_training_design.md
+++ b/docs/mindspore/source_en/design/distributed_training_design.md
@@ -120,3 +120,7 @@ As a key feature of MindSpore, automatic parallelism is used to implement hybrid
6. Backward propagation of communication operators
- [grad_comm_ops.py](https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/ops/_grad/grad_comm_ops.py): This file defines the backward propagation of communication operators, such as `AllReduce` and `AllGather`.
+
+## Heterogeneous Parallelism
+
+Subgraphs that run on different hardware and have no dependencies on each other can be executed in parallel. For detailed information, refer to [Heterogeneous Parallel Training](https://www.mindspore.cn/docs/en/master/design/heterogeneous_training.html).
\ No newline at end of file
diff --git a/docs/mindspore/source_en/design/heterogeneous_training.md b/docs/mindspore/source_en/design/heterogeneous_training.md
new file mode 100644
index 0000000000000000000000000000000000000000..2b91ecb63c478d168fea48d1c7cf629358af9702
--- /dev/null
+++ b/docs/mindspore/source_en/design/heterogeneous_training.md
@@ -0,0 +1,203 @@
+# Heterogeneous Parallel Training
+
+
+
+## Overview
+
+Heterogeneous parallel training analyzes the memory usage and computational intensity of the operators in the computational graph, assigns operators with large memory footprints or logic suited to the CPU to a CPU subgraph, and assigns compute-intensive operators with small memory footprints to a hardware accelerator subgraph. The framework coordinates the execution of these subgraphs during network training, so that subgraphs on different hardware that have no dependencies on each other can run in parallel.
+
+## Computational Process
+
+A typical computation flow for MindSpore heterogeneous parallel training is as follows:
+
+1. The user sets the backend for network execution:
+
+    ```python
+    import mindspore as ms
+    ms.set_context(device_target="GPU")
+    ```
+
+2. The user sets the execution backend of specific operators:
+
+    ```python
+    from mindspore import ops
+
+    prim = ops.Add()
+    prim.set_device("CPU")
+    ```
+
+3. The framework slices the computational graph according to these operator flags.
+4. The framework schedules the resulting subgraphs for execution on their respective backends.
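+
+Putting the two user-facing steps together, a minimal end-to-end sketch is shown below (assuming a GPU backend is available; the `Net` cell and its operators are illustrative). The framework then performs steps 3 and 4 automatically during graph compilation:
+
+```python
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
+
+class Net(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.add = ops.Add()
+        self.add.set_device("CPU")   # this Add is sliced into the CPU subgraph
+        self.mul = ops.Mul()         # Mul stays on the accelerator subgraph
+
+    def construct(self, x, y):
+        return self.mul(self.add(x, y), y)
+
+x = ms.Tensor(1.0, ms.float32)
+y = ms.Tensor(2.0, ms.float32)
+print(Net()(x, y))
+```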
+
+Typical scenarios that currently use heterogeneous parallel computing are optimizer heterogeneity, Embedding heterogeneity, and PS (Parameter Server) heterogeneity.
+
+## Optimizer Heterogeneity
+
+When training a large model such as PanGu or GPT-3, the optimizer state occupies a large amount of memory, which in turn limits the size of the model that can be trained. With optimizer heterogeneity, the optimizer is assigned to the CPU for execution, which greatly increases the size of the model that can be trained:
+
+
+
+Configuring the Adam operator to execute on the CPU while the accelerator performs the FP16 computation reduces the parameter memory footprint to 1/3 of the original.
+
+1. Configure the optimizer operators to execute on the CPU.
+2. Initialize the weight parameters as FP16 and the optimizer state variables as FP32.
+3. Convert the gradients passed to the optimizer to FP16 (skip this step if the gradients are already FP16).
+4. Convert the weights and gradients to FP32 for the optimizer computation.
+5. Assign the updated FP32 weights back to the FP16 weights.
+
+Sample code for optimizer heterogeneity is as follows:
+
+```python
+import numpy as np
+import mindspore as ms
+import mindspore.ops as ops
+from mindspore.common.initializer import initializer
+from mindspore.nn import Optimizer
+_adam_opt = ops.MultitypeFuncGraph("adam_opt")
+# Assign and Cast below are configured to run on the host (CPU); device_cast runs on the accelerator.
+host_assign = ops.Assign()
+host_assign.set_device("CPU")
+host_cast = ops.Cast()
+host_cast.set_device("CPU")
+device_cast = ops.Cast()
+
+@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
+ "Tensor", "Bool", "Bool")
+def _update_run_kernel(opt, beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
+ """
+ Update parameters by AdamWeightDecay op.
+ """
+ success = True
+ if optim_filter:
+ param32 = host_cast(param, ms.float32)
+ gradient = device_cast(gradient, ms.float32)
+ if decay_flags:
+ next_param = opt(param32, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
+ else:
+ next_param = opt(param32, m, v, lr, beta1, beta2, eps, 0.0, gradient)
+ ret = host_assign(param, host_cast(ops.depend(param32, next_param), ops.dtype(param)))
+ return ops.depend(success, ret)
+ return success
+
+class AdamWeightDecayOp(Optimizer):
+ def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
+ super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
+ self.beta1 = ms.Tensor(np.array([beta1]).astype(np.float32))
+ self.beta2 = ms.Tensor(np.array([beta2]).astype(np.float32))
+ self.eps = ms.Tensor(np.array([eps]).astype(np.float32))
+ self.moments1 = self.clone_param32(prefix="adam_m", init='zeros')
+ self.moments2 = self.clone_param32(prefix="adam_v", init='zeros')
+        self.opt = ops.AdamWeightDecay()
+        self.hyper_map = ops.HyperMap()
+        # Run the AdamWeightDecay update kernel on the CPU.
+        self.opt.set_device("CPU")
+
+ def construct(self, gradients):
+ """AdamWeightDecayOp"""
+ lr = self.get_lr()
+ if self.is_group:
+ if self.is_group_lr:
+ optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps),
+ lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
+ gradients, self.decay_flags, self.optim_filter)
+ else:
+ optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr),
+ self.weight_decay, self.parameters, self.moments1, self.moments2,
+ gradients, self.decay_flags, self.optim_filter)
+ else:
+ optim_result = self.map_reverse(ops.partial(_adam_opt, self.opt, self.beta1, self.beta2, self.eps, lr,
+ self.weight_decay), self.parameters, self.moments1, self.moments2,
+ gradients, self.decay_flags, self.optim_filter)
+ return optim_result
+
+ def clone_param32(self, prefix, init=None):
+ new = []
+ for old_param in self.parameters:
+ param_init = init
+ if init is None:
+ param_init = old_param.init
+ new_state = old_param.clone()
+ new_state.set_dtype(ms.float32)
+ new_state.set_data(initializer(param_init, shape=old_param.shape, dtype=ms.float32))
+ new_state.name = prefix + '.' + new_state.name
+ new.append(new_state)
+ return ms.ParameterTuple(new)
+```
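+
+A minimal usage sketch of the class above is shown below (assuming a toy `nn.Dense` network and random data; in a real model the parameters would be initialized in FP16 as described in step 2):
+
+```python
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+
+net = nn.Dense(16, 8)
+loss_fn = nn.MSELoss()
+# Forward and backward run on the accelerator; only the optimizer update runs on the CPU.
+opt = AdamWeightDecayOp(net.trainable_params(), learning_rate=1e-3, weight_decay=1e-2)
+train_step = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), opt)
+
+x = ms.Tensor(np.random.randn(4, 16).astype(np.float32))
+y = ms.Tensor(np.random.randn(4, 8).astype(np.float32))
+loss = train_step(x, y)
+```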
+
+Steps 4 and 5 can also be directly fused into the optimizer operator for further optimization. The complete optimizer heterogeneous training process can be found at: .
+
+## Embedding Heterogeneity
+
+Some networks need to look up large Embedding tables that are often hundreds of gigabytes in size; limited by the accelerator memory, the whole table cannot be loaded onto the accelerator for execution. Placing the operators connected to the Embedding table on the CPU avoids the situation where the accelerator cannot train the network at all because of the memory limitation.
+
+
+
+1. Configure the EmbeddingLookup operator to execute on the CPU:
+
+ ```python
+ ops.EmbeddingLookup().set_device('CPU')
+ ```
+
+2. Configure the optimizer associated with EmbeddingLookup to execute on the CPU:
+
+ ```python
+ use_locking = False
+ use_nesterov = False
+ ops.FusedSparseLazyAdam(use_locking, use_nesterov).set_device("CPU")
+ ```
+
+Sample code for setting up the EmbeddingLookup operator is as follows:
+
+```python
+import mindspore.nn as nn
+import mindspore.ops as ops
+import mindspore as ms
+from mindspore.common.initializer import initializer
+
+class EmbeddingLookup(nn.Cell):
+ def __init__(self, vocab_size, embedding_size, param_init='normal',
+ target='CPU', sparse=True):
+ """Initialize EmbeddingLookup."""
+ super(EmbeddingLookup, self).__init__()
+        if not isinstance(sparse, bool):
+            raise TypeError(f"The 'sparse' argument of 'EmbeddingLookup' must be a bool, but got {type(sparse)}.")
+        if not (isinstance(vocab_size, int) and vocab_size > 0):
+            raise ValueError(f"'vocab_size' must be a positive int, but got {vocab_size}.")
+        if not (isinstance(embedding_size, int) and embedding_size > 0):
+            raise ValueError(f"'embedding_size' must be a positive int, but got {embedding_size}.")
+        if target not in ('CPU', 'DEVICE'):
+            raise ValueError(f"The 'target' argument of 'EmbeddingLookup' must be 'CPU' or 'DEVICE', but got {target}.")
+        if not sparse and target == 'CPU':
+            raise ValueError("When target is 'CPU', 'sparse' must be True.")
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.target = target
+        self.sparse = sparse
+        if sparse:
+            self.gatherv2 = ops.SparseGatherV2()
+        else:
+            self.gatherv2 = ops.Gather()
+        # Keep the embedding table lookup on the CPU so the table can stay in host memory.
+        self.embeddinglookup = ops.EmbeddingLookup().set_device('CPU')
+ self.embedding_table = ms.Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
+ name='embedding_table')
+
+    def construct(self, indices):
+        if self.target == "CPU":
+            # Look up the embedding table in host memory on the CPU.
+            out = self.embeddinglookup(self.embedding_table, indices, 0)
+        else:
+            # Gather directly on the accelerator.
+            out = self.gatherv2(self.embedding_table, indices, 0)
+        return out
+```
+
+The EmbeddingLookup, FTRL, LazyAdam, and other operators in the current nn directory already encapsulate this heterogeneous interface; users only need to set the target attribute to CPU or DEVICE to switch the execution backend.
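+
+For example, a minimal sketch using the encapsulated `nn.EmbeddingLookup` (the sizes here are illustrative):
+
+```python
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+
+# The embedding table stays in host memory and the lookup runs on the CPU.
+embedding = nn.EmbeddingLookup(vocab_size=10000, embedding_size=128,
+                               target='CPU', sparse=True)
+indices = ms.Tensor(np.array([[1, 5], [7, 9]], dtype=np.int32))
+out = embedding(indices)   # shape: (2, 2, 128)
+```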
+
+For the overall calling process, refer to .
+
+## PS Heterogeneity
+
+When the Embedding table reaches the terabyte level and no longer fits into the memory of a single machine, a Parameter Server is used, and the weights are pulled and updated through the heterogeneous Pull/Push operators.
+
+
+
+The Parameter Server encapsulates the heterogeneous process; users only need to configure parameters to use the PS. For the detailed configuration process, refer to the [Parameter Server training process](https://www.mindspore.cn/tutorials/experts/en/master/parallel/parameter_server_training.html).
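+
+As a rough sketch of the kind of configuration involved (the network here is illustrative; the tutorial linked above is the authoritative reference):
+
+```python
+import mindspore as ms
+import mindspore.nn as nn
+
+# Enable Parameter Server mode before building the network.
+ms.set_ps_context(enable_ps=True)
+
+net = nn.Dense(16, 8)
+# Store and update the weights of this network on the parameter server.
+net.set_param_ps()
+# The role of each process (worker, server, scheduler) is configured through
+# environment variables; see the tutorial linked above for details.
+```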
+
+In addition, the Wide&Deep network also uses the PS; its implementation can be found at: .
+
+## Constraints
+
+Currently, the user must specify the execution backend of each operator; automatic configuration based on the network is not supported.
diff --git a/docs/mindspore/source_en/design/side_effect.md b/docs/mindspore/source_en/design/side_effect.md
new file mode 100644
index 0000000000000000000000000000000000000000..30d312e9ad7552dc37cd88576672b7f70f2750f0
--- /dev/null
+++ b/docs/mindspore/source_en/design/side_effect.md
@@ -0,0 +1,133 @@
+# Side Effects
+
+
+
+## Concepts
+
+### Pure Function
+
+A pure function is a function whose return value depends only on its arguments and that has no side effects.
+A pure function is close to a function in the mathematical sense: for the same arguments, it always returns the same value.
+If a program contains only pure functions, the order in which they are evaluated does not affect the result.
+For example, in the following code, assuming that `add` is a pure function, the order in which `a` and `b` are evaluated does not affect the result of `c`.
+
+```python
+ a = add(1, 2)
+ b = add(3, 4)
+ c = add(a, b)
+```
+
+### Side Effects
+
+A function has side effects if it changes external state, or if it has any observable effect other than returning a value.
+Examples include modifying a global variable, modifying the value of a reference-type argument, performing input or output operations, and calling other functions that have side effects.
+When side effects are present, the behavior of the program may depend on the order of evaluation.
+For example, in the following code, suppose `add` is a pure function and `assign` is a function with side effects (it changes its input parameter `x`); evaluating `a`, `b`, and `c` in different orders leads to different results for `d`.
+
+```python
+ a = add(1, x)
+ b = assign(x, 100) # side effect
+ c = add(3, x)
+ d = add(a, c)
+```
+
+Because of the side effect, `a`, `b`, and `c` in the above program must be evaluated strictly in the order in which they appear in the code; otherwise the result may be unexpected.
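+
+To make this concrete, the same situation can be written in plain Python, with a one-element list standing in for the mutable `x` (illustration only):
+
+```python
+def add(a, b):
+    return a + b
+
+def assign(box, value):
+    box[0] = value      # side effect: mutates the caller's state
+    return value
+
+x = [10]
+a = add(1, x[0])        # 11
+b = assign(x, 100)      # x is now [100]
+c = add(3, x[0])        # 103 -- would be 13 if evaluated before the assign
+d = add(a, c)           # 114
+```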
+
+## Design
+
+MindSpore uses a graph-based functional intermediate representation; see [MindIR](https://www.mindspore.cn/docs/en/master/design/mindir.html).
+Conceptually, functions in MindIR are pure functions and have no side effects.
+However, MindSpore supports computational models with side effects and provides operators with side effects, such as optimizer operators that modify their input parameters in place.
+To support such operators and models, MindSpore converts the side effects in the code into a pure functional form when compiling the model, ensuring that computations with side effects are executed in the desired order while MindIR keeps its pure functional semantics.
+
+### Converting Side Effects to Pure Functions
+
+To convert a function with side effects into a pure functional form, MindSpore treats the external state affected by the function as a data object: the state object is passed to the function as an additional input, and the modified state object is returned as an additional output:
+
+```python
+ ret = func_with_side_effect(args)
+```
+
+is converted to:
+
+```python
+ ret, state1 = pure_func(args, state0)
+```
+
+Here the return value of `pure_func` depends only on its inputs; the input state `state0` is left unchanged and the updated state is returned as `state1`, so `pure_func` can be regarded as a pure function.
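+
+As an illustration in plain Python (not MindSpore internals), the transformation can be pictured as follows: the mutable external state becomes an explicit input and output of the function:
+
+```python
+# Illustration only: a dict stands in for the external state.
+def assign_with_side_effect(state, key, value):
+    state[key] = value            # mutates external state
+    return value
+
+def assign_pure(state, key, value):
+    new_state = dict(state)       # the old state is left untouched
+    new_state[key] = value
+    return value, new_state       # return the result and the updated state
+
+state0 = {"x": 1}
+ret, state1 = assign_pure(state0, "x", 100)
+assert state0 == {"x": 1} and state1 == {"x": 100}
+```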
+
+### Intermediate Representation of Side Effects
+
+Since functions in MindIR do not support multiple return values, MindSpore introduces a virtual operator `UpdateState`, and the `pure_func` above is expressed in the intermediate representation as:
+
+```python
+ ret = pure_func(args, state0)
+ state1 = UpdateState(state0, ret)
+```
+
+In addition, to ensure the correct ordering of reads and writes, MindSpore introduces a `Load` operator. If the input to a function is a global parameter, a `Load` is inserted to ensure that the function reads the correct value of the parameter.
+For example, `add` in the following code reads the global parameter `self.param`:
+
+```python
+ out = add(self.param, x)
+```
+
+MindSpore converts this to an intermediate representation of the following form:
+
+```python
+ p = Load(self.param, state0)
+ state1 = UpdateState(state0, p)
+ out = add(p, x)
+```
+
+### Classifications of Side Effects
+
+MindSpore classifies side effects into three types, depending on the kind of external state they affect:
+
+1. Memory side effects: affect state in memory, such as modifying a global variable or an input parameter.
+
+2. Input and output (IO) side effects: perform input or output operations, such as printing information to the console.
+
+3. Hidden side effects: there is no obvious change of external state, but a hidden state is actually changed. For example, a random number generation operator changes the state of the random number generator.
+
+In MindSpore, memory side effects and IO side effects are represented by separate state objects, so these two types of side effects form two independent execution sequences.
+
+Hidden side effects are not represented by separate state objects and execution sequences, because there is no explicit external state to track. Instead, MindSpore performs special internal handling for them, such as preventing the fusion of two random number generation operators so that no incorrect results are produced.
+
+### Side Effect Operator Mark
+
+An operator is marked as having side effects by adding specific attributes. MindSpore supports the following attributes for marking the side effects of an operator:
+
+- `side_effect_mem`: memory side effect
+- `side_effect_io`: input and output side effect
+- `side_effect_hidden`: hidden side effect
+
+For example, to mark an operator as having memory side effects:
+
+```python
+ @prim_attr_register
+ def __init__(self):
+ ...
+ self.add_prim_attr('side_effect_mem', True)
+```
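+
+For instance, a more complete sketch of a custom primitive marked with a memory side effect might look like the following (the `MyAssign` class is illustrative and omits kernel registration):
+
+```python
+from mindspore.ops import PrimitiveWithInfer, prim_attr_register
+
+class MyAssign(PrimitiveWithInfer):
+    """Illustrative primitive that writes to its first input in place."""
+
+    @prim_attr_register
+    def __init__(self):
+        # Tell the compiler that this primitive modifies memory, so its execution
+        # order relative to other reads and writes is preserved.
+        self.add_prim_attr('side_effect_mem', True)
+
+    def infer_shape(self, variable_shape, value_shape):
+        return variable_shape
+
+    def infer_dtype(self, variable_dtype, value_dtype):
+        return variable_dtype
+```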
+
+MindSpore can ensure that side effects are executed in the desired order only if they are correctly marked.
+
+## Related Scenarios
+
+MindSpore automatically identifies side effects in the code and ensures that they are executed in the correct order.
+
+Therefore, in the vast majority of cases, model developers and users do not need to care whether the model has side effects or how to ensure the correct execution order.
+
+### Operator Development
+
+If a newly developed operator has side effects, its attributes must correctly indicate that it has side effects and which kind of side effects they are. Otherwise, a model that uses the operator risks producing incorrect results because the evaluation order may not be what is expected.
+
+### Model Development
+
+Typically, model developers do not need to be concerned with side effects, but understanding how side effects are handled can help anticipate the order of code execution. Knowing which operators have side effects also helps in making better operator choices.
+
+### MindIR
+
+If the model has side effects, the exported MindIR contains `UpdateState` and `Load` nodes, whose role is to handle the side effects and preserve execution order.
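+
+For example, one way to observe these nodes is to dump the intermediate graphs while compiling a cell that writes to a parameter (a sketch; the dumped `.ir` text files can then be searched for `UpdateState` and `Load`):
+
+```python
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.ops as ops
+
+ms.set_context(mode=ms.GRAPH_MODE, save_graphs=True, save_graphs_path="./graphs")
+
+class Net(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.param = ms.Parameter(ms.Tensor(1.0, ms.float32), name="param")
+        self.assign = ops.Assign()       # operator with a memory side effect
+
+    def construct(self, x):
+        self.assign(self.param, x)       # write the parameter, then read it below
+        return self.param + x
+
+out = Net()(ms.Tensor(2.0, ms.float32))
+# The .ir files under ./graphs contain Load and UpdateState nodes that keep the
+# assignment ordered before the read.
+```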