diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb13bdc9ae9c819a33c3de4392cf6ba7423083db --- /dev/null +++ b/tools/flight_recorder/components/builder.py @@ -0,0 +1,307 @@ +import argparse +import ast +import os +from typing import Any + +from tools.flight_recorder.components.fr_logger import FlightRecorderLogger +from tools.flight_recorder.components.types import ( + Collective, + Database, + EntryState, + Group, + MatchStateRecord, + Membership, + HCCLCall, + Op, + Traceback, +) +from tools.flight_recorder.components.utils import ( + ProcessGroupData, + align_trace_from_beginning, + check_current_entry_match, + check_no_missing_dump_files, + check_version, + EntryContext, + error_analysis, + get_version_detail, + just_print_entries, +) + + +# Set up logging +logger: FlightRecorderLogger = FlightRecorderLogger() + + +try: + from tabulate import tabulate +except ModuleNotFoundError: + logger.warning("tabulate is not installed. Proceeding without it.") + + # Define a no-op tabulate function + def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc] + return data + + +""" +Flat DB builder +""" + + +def build_groups_memberships( + pg_config: Any, +) -> tuple[ + list[Group], + dict[Any, Group], + list[Membership], + dict[str, set[Any]], + dict[tuple[str, int], str], +]: + """ + pg_config: { + global_rank: { + (pg_guid, desc, ranks) + } + } + + `pg_guid` is a system generated id, but depending on the mode of PG creation it could be a globally incrementing int + or a hash of the ranks. See `_process_group_name` in distributed_c10d.py. + `desc` is provided by the user (optionally) and should be 'meaningful' (e.g. TP/PP/DP group) + `ranks` is a list of the 'global ranks' that are members of the PG. 
+ + (pg_guid, desc, ranks) tuples are appended lazily to the flight buffer when `getHCCLComm` is called on a PG and + the `enabled_` flag is true for that PG. + - the order of calling (init_process_group, new_group, etc) does not affect the order of the tuples in the list + + Returns: + `groups`: a groups table where each row is a Group namedtuple. + `_groups`: a dict that is indexed by pg_guid with Group namedtuple as value. + `memberships`: a membership table where each row is a Membership namedtuple. + `_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value. + `_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value. + """ + # flat lists for return + groups = [] + memberships = [] + + # dicts for faster cross-rank validation + _groups = {} + _memberships = {} + _pg_guids = {} + for global_rank in pg_config: + for pg_uid in pg_config[global_rank]: + desc = pg_config[global_rank][pg_uid]["desc"] + ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"]) + # With the adoption of the split_group API, we can have multiple PGs with the same pg_guid (PG Name) + # So we need to add the hash of all its ranks within the PG as well. + # Also guid must be a string because `_process_group_name` returns a string. 
+ pg_guid = pg_uid + str(hash(frozenset(ranks))) + _pg_guids[(pg_uid, global_rank)] = pg_guid + if isinstance(ranks, str): + ranks = ast.literal_eval(ranks) + if pg_guid not in _groups: + groups.append(Group(id=pg_guid, desc=desc, size=len(ranks))) + for rank in ranks: + memberships.append(Membership(group_id=pg_guid, global_rank=rank)) + _groups[pg_guid] = groups[-1] + _memberships[pg_guid] = set(ranks) + else: + # validation across ranks + if _groups[pg_guid].desc != desc: + raise ValueError( + f"Description mismatch for group {pg_guid}: " + f"expected '{desc}', got '{_groups[pg_guid].desc}'" + ) + + if _memberships[pg_guid] != set(ranks): + raise ValueError( + f"Membership mismatch for group {pg_guid}: " + f"expected {set(ranks)}, got {_memberships[pg_guid]}" + ) + + return groups, _groups, memberships, _memberships, _pg_guids + + +def build_collectives( + all_entries: dict[int, list[dict[str, Any]]], + _groups: dict[str, Group], + _memberships: dict[str, set[Any]], + _pg_guids: dict[tuple[str, int], str], + version: str, +) -> tuple[list[Traceback], list[Collective], list[HCCLCall]]: + """ + groups, memberships are the non-flat dicts that are indexable + all_entries is a raw dict from the original dumps: + + all_entries: { + global_rank: [ + { + record_id: ordered id of the event in the trace buffer + pg_id: ProcessGroupHCCL::uid_ + *note: `pg_id` corresponds to nothing in groups table + process_group: (pg_name, desc) + *note: `pg_name`, `desc` corresponds to `pg_id`, `desc` in groups table + collective_seq_id: ordered id for collective operations and coalesced group operations + p2p_seq_id: ordered id for point-to-point operations + op_id: ordered id including individual ops inside coalescing group + profiling_name: descriptive name of the operation + 'time_created_ns', + 'input_sizes', + 'output_sizes', + 'state', + 'time_discovered_started_ns', + 'time_discovered_completed_ns', + 'retired', + 'frames', + } + ] + } + """ + tracebacks: list[Traceback] = [] + + 
collectives: list[Collective] = []
+    hccl_calls: list[HCCLCall] = []
+
+    # once we find one mismatch, we stop pairing up collectives since the pairing is possibly incorrect
+    # instead, just record the remaining ops as HCCLCalls
+    mismatch = {_groups[g].id: 0 for g in _groups}
+    MISMATCH_TAIL = 10
+
+    # For best effort partial analysis.
+    dumps_ranks = set()
+    for key in all_entries.keys():
+        try:
+            dumps_ranks.add(int(key))
+        except ValueError as e:
+            raise ValueError(f"Cannot extract rank from '{key}'") from e
+    """
+    - it doesn't matter what order I put collectives/hcclops into their table. we can later on re-sort it by start time
+    - there could be multiple options for the "first" collective to pair up (rank 0,1 might do a bcast while rank 2,3 do a bcast)
+    - within a group, the first collective must be the same on all ranks in the group, then it can be marked as a
+      collective and removed
+    """
+    while all_entries:
+        # we greedily match collectives, starting arbitrarily with the trace from the first rank
+        # later, if we exhaust the first rank, we continue with the next 'first rank'
+        rank_iter = iter(all_entries)
+        first_rank = next(rank_iter)
+        other_ranks = list(rank_iter)
+
+        if len(all_entries[first_rank]) == 0:
+            all_entries.pop(first_rank)
+            continue
+
+        # lets match the first collective! we need to know which ranks are involved, and ensure that this same
+        # collective is also the first one on those ranks within that group
+        entries = all_entries[first_rank]
+        current_entry = entries[0]
+
+        desc = current_entry["process_group"][1] if current_entry["process_group"][1] else "default_pg"
+        # For db build and logs printing, we want to use the original pg_name, not the hash one.
+ original_pg_name = current_entry["process_group"][0] + pg_name = _pg_guids[(original_pg_name, first_rank)] + expected_ranks = set(_memberships[pg_name]) + entry_state = EntryState(current_entry, expected_ranks) + match_record = MatchStateRecord( + expected_ranks=expected_ranks, + other_ranks=other_ranks, + entry_state=entry_state, + candidate_ranks={first_rank}, + candidate_idx={}, + found_ranks=set(), + found_idx={}, + errors=set(), + ) + + check_current_entry_match( + all_entries=all_entries, + current_entry=current_entry, + _memberships=_memberships, + pg_data=ProcessGroupData(pg_guids=_pg_guids, pg_name=pg_name, desc=desc, mismatch=mismatch), + match_record=match_record, + ) + + # Use heuristics to decide what type of errors and error messages we should print. + error_analysis( + entry_context=EntryContext(all_entries, current_entry, dumps_ranks, first_rank), + match_record=match_record, + mismatch=mismatch, + version=get_version_detail(version), + pg_name=pg_name, + ) + # at this point there are 3 possibilities + # 1. we found a match on all the ranks that are members of the group + # -> we create a Collective and remove the individual entries from their original lists + if match_record.found_ranks == expected_ranks and mismatch[pg_name] == 0: + collectives.append(match_record.entry_state.to_collective(len(collectives))) + idx_map = {r: match_record.found_idx[r] if r != first_rank else 0 for r in match_record.found_ranks} + hccl_calls.extend( + match_record.entry_state.to_hccl_call(all_entries, idx_map, len(hccl_calls), collectives[-1].id) + ) + + # 2. we found a partial match but some ranks are missing + # 3. 
we found no match + else: + logger.debug("appending a non-matching collective") + idx_map = {r: match_record.candidate_idx[r] if r != first_rank else 0 for r in match_record.candidate_ranks} + collectives.append( + match_record.entry_state.to_collective( + len(collectives), + errors=match_record.errors, + idx_map=idx_map, + all_entries=all_entries, + ) + ) + hccl_calls.extend(match_record.entry_state.to_hccl_call(all_entries, idx_map, len(hccl_calls), None)) + + if mismatch[pg_name] > MISMATCH_TAIL: + logger.error("Too many mismatches for process_group %s: %s aborting", pg_name, desc) + break + return tracebacks, collectives, hccl_calls + + +def build_db(details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str) -> Database: + if args.verbose: + os.environ["FR_TRACE_VERBOSE_OUTPUT"] = "1" + # temporary state used for building database + entries = {} + pg_config = {} + version_by_ranks = {} + for rank, dump in details.items(): + entries[rank] = dump["entries"] + version_by_ranks[rank] = dump["version"] + pg_config[rank] = dump["pg_config"] + + # Ensure version is consistent across all ranks. 
+ check_version(version_by_ranks, version) + entries = align_trace_from_beginning(entries) + + # flattened database + groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(pg_config) + logger.debug("built groups, memberships") + + if not args.allow_incomplete_ranks: + check_no_missing_dump_files(entries, memberships) + + if args.just_print_entries: + just_print_entries(entries, _groups, _memberships, _pg_guids, args) + return None + + tracebacks, collectives, hccl_calls = build_collectives(entries, _groups, _memberships, _pg_guids, version) + logger.debug("built collectives, hccl_calls") + if args.verbose: + logger.debug("Groups") + logger.debug(tabulate(groups, headers=Group._fields)) + logger.debug("Memberships") + logger.debug(tabulate(memberships, headers=Membership._fields)) + logger.debug("Collectives") + logger.debug(tabulate(collectives, headers=Collective._fields)) + logger.debug("HCCLCalls") + logger.debug(tabulate(hccl_calls, headers=HCCLCall._fields)) + db = Database( + tracebacks=tracebacks, + collectives=collectives, + hcclcalls=hccl_calls, + groups=groups, + memberships=memberships, + ) + return db diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py new file mode 100644 index 0000000000000000000000000000000000000000..40cb08a1a40522a917dffc1d6dc11ccc7c44a5f4 --- /dev/null +++ b/tools/flight_recorder/components/types.py @@ -0,0 +1,550 @@ +import math +import os +from enum import auto, Enum +from typing import ( + _eval_type, + Any, + Generic, + NamedTuple, + Optional, + TypeVar, +) + +from tools.flight_recorder.components.fr_logger import FlightRecorderLogger + + +T = TypeVar("T", bound=NamedTuple) + + +class Ref(Generic[T]): + pass + + +class TypeInfo(NamedTuple): + name: str + fields: list[tuple[str, type]] # type: ignore[type-arg] + + @classmethod + def from_type(cls, c: T) -> "TypeInfo": + if hasattr(c, "__name__"): + name = c.__name__ + else: + name = str(c) + return cls( + 
name, + [(f, _eval_type(c.__annotations__[f], globals(), {})) for f in c._fields], + ) + + +class MatchState(Enum): + """ + Enum representing the possible states of matching for collective operations. + + - FULLY_MATCHED: Indicates that all aspects of the collective operations match. + - COLLECTIVE_TYPE_MISMATCH: The types of the collective operations differ. + - SIZE_OR_SYNTAX_MISMATCH: There is a mismatch in input/output sizes or violation of collective syntax. + - COLLECTIVE_STATE_MISMATCH: + The states of the collective not same, such as one finished while another just started or scheduled. + - COLLECTIVE_DTYPE_MISMATCH: The data types of the collective input/output differ. + - UNDECIDED: + The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. + """ + + FULLY_MATCHED = auto() + COLLECTIVE_TYPE_MISMATCH = auto() + SIZE_OR_SYNTAX_MISMATCH = auto() + COLLECTIVE_STATE_MISMATCH = auto() + COLLECTIVE_DTYPE_MISMATCH = auto() + UNDECIDED = auto() + + +class MatchInfo: + """ + Aside from the match state, we also store some dynamic info for the match such as the culprit rank + or collective state that caused the mismatch. 
+ """ + + def __init__(self, state: MatchState, culprit: Optional[str] = None) -> None: + self._state = state + self.culprit = culprit + + def __str__(self) -> str: + details = f", {self.culprit}" if getattr(self, "culprit", None) else "" + return f"Error type: {self._state.name}{details}" + + @property + def state(self) -> MatchState: + return self._state + + +class Group(NamedTuple): + id: str + desc: str + size: int + + +class Membership(NamedTuple): + group_id: str + global_rank: int + + +class Traceback(NamedTuple): + id: int + frames: str + + +class Collective(NamedTuple): + id: int + group_id: str + pass_check: bool + collective_seq_id: int + p2p_seq_id: int + record_id: int + pg_desc: str + collective_name: str + input_sizes: list[list[int]] + output_sizes: list[list[int]] + expected_ranks: set[int] + collective_state: str + collective_frames: list[dict[str, str]] + input_numel: Optional[int] = None + output_numel: Optional[int] = None + missing_ranks: Optional[set[int]] = None + mismatch_collectives: Optional[dict[int, "Collective"]] = None + type_of_mismatch: Optional[MatchInfo] = None + + +class HCCLCall(NamedTuple): + id: int + collective_id: Ref[Collective] + group_id: str + global_rank: int # technically Ref[Process] once we have it + traceback_id: Ref[Traceback] + collective_type: str + sizes: list[list[int]] + + +class Database(NamedTuple): + groups: list[Group] + memberships: list[Membership] + tracebacks: list[Traceback] + collectives: list[Collective] + hcclcalls: list[HCCLCall] + + +types = [ + TypeInfo.from_type(t) # type: ignore[type-var] + for t in globals().values() + if (isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields") and t is not TypeInfo) +] + + +COLLECTIVES = { + "broadcast", + "_broadcast_oop", + "reduce", + "_reduce_oop", + "all_gather", + "all_reduce", + "_all_gather_base", + "all_gather_into_tensor_coalesced", + "reduce_scatter", + "reduce_scatter_tensor_coalesced", + "_reduce_scatter_base", + "gather", + 
"scatter", + "all_to_all", + "all_reduce_barrier", + "allreduce_coalesced", + "ALLGATHER_coalesced", + "REDUCE_SCATTER_coalesced", +} + +P2P = { + "send", + "recv", +} + + +class EntryState: + """ + Util class to keep track of the state of an entry and standardize the way we + log the error info during analysis. + """ + + def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None: + self.pg_name = entry["process_group"][0] + self.desc = entry["process_group"][1] + self.pg_desc = f"{self.pg_name}:{self.desc}" if self.desc != "undefined" else self.pg_name + self.profiling_name = entry["profiling_name"] + self.collective_seq_id = entry["collective_seq_id"] + self.p2p_seq_id = entry["p2p_seq_id"] + self.record_id = entry["record_id"] + self.input_sizes = entry["input_sizes"] + self.output_sizes = entry["output_sizes"] + self.collective_state = entry["state"] + self.collective_frames = entry.get("frames", []) + self.expected_ranks = expected_ranks + self.missing_ranks: set[int] + self.input_numel: int + self.output_numel: int + self.errors: set[tuple[int, MatchInfo]] + + + def log( + self, + logger: FlightRecorderLogger, + logger_msg: str, + frame_formatter: Any, + additional_info: dict = None, + ) -> None: + logger.info( + logger_msg, + self.collective_seq_id, + ) + logger.info("internal record id: %s", self.record_id) + logger.info("group info: %s", self.pg_desc) + logger.info("collective: %s", self.profiling_name) + if additional_info and "missing_ranks" in additional_info: + missing_ranks = additional_info["missing_ranks"] + self.missing_ranks = missing_ranks + logger.info("missing ranks: %s", missing_ranks) + if additional_info and "total_numel" in additional_info: + total_numel = additional_info["total_numel"] + self.input_numel = total_numel[0] + self.output_numel = total_numel[1] + logger.info("total input numel: %d", total_numel[0]) + logger.info("total output numel: %d", total_numel[1]) + logger.info("input sizes: %s", self.input_sizes) + 
logger.info("output sizes: %s", self.output_sizes) + logger.info("world size: %d", len(self.expected_ranks)) + logger.info("expected ranks: %s", str(self.expected_ranks)) + logger.info("collective state: %s", self.collective_state) + if additional_info and "errors" in additional_info: + errors = additional_info["errors"] + self.errors = errors + error_msg = ", ".join(f"Culprit rank {error[0]}; {str(error[1])}" for error in errors) + logger.info("error msg: %s", error_msg) + logger.info("collective stack trace: \n %s", frame_formatter(self.collective_frames)) + + def to_collective( + self, + collective_id: int, + errors: Optional[set[tuple[int, MatchInfo]]] = None, + idx_map: Optional[dict[int, int]] = None, + all_entries: Optional[dict[int, list[dict[str, Any]]]] = None, + ) -> Collective: + if not errors: + return Collective( + id=collective_id, + group_id=self.pg_name, + record_id=self.record_id, + pg_desc=self.pg_desc, + pass_check=True, + collective_seq_id=self.collective_seq_id, + p2p_seq_id=self.p2p_seq_id, + collective_name=self.profiling_name, + input_sizes=self.input_sizes, + output_sizes=self.output_sizes, + expected_ranks=self.expected_ranks, + collective_state=self.collective_state, + collective_frames=self.collective_frames, + missing_ranks=getattr(self, "missing_ranks", None), + ) + else: + if idx_map is None: + raise ValueError("idx_map cannot be None") + if all_entries is None: + raise ValueError("all_entries cannot be None") + mismatch_collectives = {} + for rank, error in errors: + idx = idx_map[rank] + entry = all_entries[rank][idx] + desc = entry["process_group"][1] + pg_name = entry["process_group"][0] + mismatch_collectives[rank] = Collective( + id=collective_id, + group_id=entry["process_group"][0], + record_id=entry["record_id"], + pg_desc=f"{pg_name}:{desc}" if desc != "undefined" else pg_name, + pass_check=False, + collective_seq_id=entry["collective_seq_id"], + p2p_seq_id=entry["p2p_seq_id"], + collective_name=entry["profiling_name"], + 
input_sizes=entry["input_sizes"], + output_sizes=entry["output_sizes"], + expected_ranks=self.expected_ranks, + collective_state=entry["state"], + collective_frames=entry.get("frames", []), + type_of_mismatch=error, + ) + return Collective( + id=collective_id, + group_id=self.pg_name, + record_id=self.record_id, + pg_desc=self.pg_desc, + pass_check=False, + collective_seq_id=self.collective_seq_id, + p2p_seq_id=self.p2p_seq_id, + collective_name=self.profiling_name, + input_sizes=self.input_sizes, + output_sizes=self.output_sizes, + expected_ranks=self.expected_ranks, + collective_state=self.collective_state, + collective_frames=self.collective_frames, + input_numel=self.input_numel if hasattr(self, "input_numel") else None, + output_numel=self.output_numel if hasattr(self, "output_numel") else None, + missing_ranks=self.missing_ranks if hasattr(self, "missing_ranks") else None, + mismatch_collectives=mismatch_collectives, + ) + + def to_hccl_call( + self, + all_entries: dict[int, list[dict[str, Any]]], + idx_map: dict[int, int], + hccl_call_id: int, + collective_id: Any, + ) -> list[HCCLCall]: + result = [] + for i, k in idx_map.items(): + all_entries[i].pop(k) + result.append( + HCCLCall( + id=hccl_call_id, + collective_id=collective_id, + group_id=self.pg_name, # type: ignore[arg-type] + global_rank=i, + traceback_id=0, # type: ignore[arg-type] + collective_type=self.profiling_name, + sizes=self.input_sizes, + ) + ) + hccl_call_id += 1 + return result + + +class Op: + """Parses relevant info about operation out of 'event' dict + + examples of supported `profiling_name`s: + hccl:broadcast + hccl:send 1->2 + hccl:recv 3<-0 + """ + MISSING_FRAMES_ERR = "Event missing 'frames' field or empty frames array" + INVALID_FRAME_ERR = "Frame[0] missing 'name' field" + + + def __init__(self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str): + + frames = event.get("frames") + if not frames: + raise ValueError(self.MISSING_FRAMES_ERR) + first_frame = 
frames[0] if len(frames) > 0 else None
+        if not first_frame:
+            raise ValueError(self.MISSING_FRAMES_ERR)
+        self.profiling_name = first_frame.get("name")
+        if self.profiling_name is None:
+            raise ValueError(self.INVALID_FRAME_ERR)
+        parts = self.profiling_name.split(":")
+        self.type = parts[0]
+        meta = parts[1] if len(parts) == 2 else None
+        self.state = event.get("state")
+        self.pg_name, self.pg_desc = event.get("process_group")
+        if self.type == "send":
+            s, d = meta.split("->")
+            self._src, self._dst = int(s), int(d)
+        elif self.type == "recv":
+            d, s = meta.split("<-")
+            self._dst, self._src = int(d), int(s)
+        else:
+            self._src, self._dst = -1, -1
+        self._init_global_src_dst(memberships[pg_name])
+        self.pg_size = len(memberships[pg_name])
+        if self.type in P2P | COLLECTIVES:
+            self.input_sizes = event.get("input_sizes")
+            self.output_sizes = event.get("output_sizes")
+        else:
+            self.input_sizes, self.output_sizes = None, None
+        self.collective_seq_id = event.get("collective_seq_id")
+        self.p2p_seq_id = event.get("p2p_seq_id")
+        self.input_dtypes = event.get("input_dtypes")
+        self.output_dtypes = event.get("output_dtypes")
+        self.time_created_ns = event.get("time_created_ns")
+        self.collective_frames = event.get("frames", [])
+        self.is_verbose = os.getenv("FR_TRACE_VERBOSE_OUTPUT", "0") == "1"
+
+    def _init_global_src_dst(self, pg_ranks: set[Any]) -> None:
+        pg_ranks = sorted(pg_ranks)
+        self._src_g = pg_ranks[self._src] if self._src is not None else None
+        self._dst_g = pg_ranks[self._dst] if self._dst is not None else None
+
+    @property
+    def src(self) -> int:
+        if self.type not in P2P:
+            raise ValueError(f"Can't get src of non-p2p op (type: {self.type})")
+        return self._src
+
+    @property
+    def dst(self) -> int:
+        if self.type not in P2P:
+            raise ValueError(f"Can't get dst of non-p2p op (type: {self.type})")
+        return self._dst
+
+    def __repr__(self) -> str:
+        p2p_info = ""
+        if self.type in P2P:
+            p2p_info = f"s={self._src_g} d={self._dst_g}"
+        if self.is_verbose:
+            
verbose_info = ( + f"timestamp_created={self.time_created_ns}", + p2p_info, + f"input_sizes={self.input_sizes}", + f"output_sizes={self.output_sizes}", + f"input_dtypes={self.input_dtypes}", + f"output_dtypes={self.output_dtypes}", + "collective_seq_id | p2p_seq_id=" f"{self.p2p_seq_id if self.type in P2P else self.collective_seq_id}", + f"pg_name={self.pg_name}", + f"pg_description={self.pg_desc}", + f"pg_size={self.pg_size}", + f"state={self.state}", + ) + return f"{self.type}({', '.join(s for s in verbose_info if s)})" + return f"{self.type}(%sinput_sizes={self.input_sizes}, state={self.state})" % ( + f"{p2p_info}, " if p2p_info else "" + ) + + def has_different_dtypes_and_non_empty_sizes(self, other): + """ + Check if the input/output dtypes are different and the sizes are non-empty. + """ + # Check if input/output dtypes are different and sizes are non-empty + condition1 = set(self.input_dtypes) != set(self.output_dtypes) and self.input_sizes[0] and self.output_sizes[0] + condition2 = set(self.input_dtypes) != set(other.input_dtypes) and self.input_sizes[0] and other.input_sizes[0] + condition3 = ( + set(self.input_dtypes) != set(other.output_dtypes) and self.input_sizes[0] and other.output_sizes[0] + ) + return condition1 or condition2 or condition3 + + def match(self, other: "Op") -> MatchInfo: + if self.type == "send": + return ( + MatchInfo(MatchState.FULLY_MATCHED) + if ( + other.type == "recv" + and self.src == other.src + and self.dst == other.dst + and self.input_sizes == other.output_sizes + ) + else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) + ) + elif self.type == "recv": + return ( + MatchInfo(MatchState.FULLY_MATCHED) + if ( + other.type == "send" + and self.src == other.src + and self.dst == other.dst + and self.output_sizes == other.input_sizes + ) + else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) + ) + elif self.type in COLLECTIVES: + if self.type != other.type: + return MatchInfo( + MatchState.COLLECTIVE_TYPE_MISMATCH, + f"Expected 
collective type: '{self.type}' does not match found collective type: '{other.type}'",
+            )
+        if self.state != other.state:
+            return MatchInfo(
+                MatchState.COLLECTIVE_STATE_MISMATCH,
+                f"Expected state: '{self.state}' does not match found state: '{other.state}'",
+            )
+        if self.has_different_dtypes_and_non_empty_sizes(other):
+            return MatchInfo(
+                MatchState.COLLECTIVE_DTYPE_MISMATCH,
+                f"Expected dtypes: '{set(self.input_dtypes)}' does not "
+                f"match found dtype: '{set(self.output_dtypes)}/"
+                f"{set(other.input_dtypes)}/{set(other.output_dtypes)}'",
+            )
+        if self.type == "all_to_all":
+            return MatchInfo(MatchState.UNDECIDED)
+        if self.type != "scatter" and self.input_sizes != other.input_sizes:
+            return MatchInfo(
+                MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                f"Expected input sizes: '{self.input_sizes}' does not match found input sizes: "
+                f"'{other.input_sizes}'",
+            )
+        if self.type != "gather" and self.output_sizes != other.output_sizes:
+            return MatchInfo(
+                MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                f"Expected output sizes: '{self.output_sizes}' does not match found output sizes: "
+                f"'{other.output_sizes}'",
+            )
+        if self.type in ["all_reduce", "allreduce_coalesced"] and self.input_sizes != other.output_sizes:
+            return MatchInfo(
+                MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                f"Expected input sizes: '{self.input_sizes}' does not match found output sizes: '{other.output_sizes}'",
+            )
+        if self.type in [
+            "all_gather",
+            "all_gather_base",
+            "all_gather_into_tensor_coalesced",
+        ] and not (math.prod(other.output_sizes[0]) == math.prod(self.input_sizes[0]) * self.pg_size):
+            return MatchInfo(
+                MatchState.SIZE_OR_SYNTAX_MISMATCH,
+                f"Found input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' "
+                f"does not match output numel '{math.prod(other.output_sizes[0])}'",
+            )
+        if self.type in [
+            "reduce_scatter",
+            "_reduce_scatter_base",
+            "reduce_scatter_tensor_coalesced",
+        ] and not (math.prod(other.input_sizes[0]) == math.prod(self.output_sizes[0]) * 
self.pg_size): + return MatchInfo( + MatchState.SIZE_OR_SYNTAX_MISMATCH, + f"Found input numel '{math.prod(other.input_sizes[0])}' does not match output numel " + f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'", + ) + elif self.type in [ + "coalesced", + "ALLGATHER_coalesced", + "REDUCE_SCATTER_coalesced", + ]: + return ( + MatchInfo(MatchState.FULLY_MATCHED) + if (other.type == self.type) + else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) + ) + return MatchInfo(MatchState.FULLY_MATCHED) + + +class MatchStateRecord: + def __init__( + self, + expected_ranks: set[int], + other_ranks: list[int], + entry_state: EntryState, + candidate_ranks: set[int], + candidate_idx: dict[int, int], + found_ranks: set[int], + found_idx: dict[int, int], + errors: set[tuple[int, MatchInfo]], + ) -> None: + self.expected_ranks = expected_ranks + self.other_ranks = other_ranks + self.entry_state = entry_state + self.candidate_ranks = candidate_ranks + self.candidate_idx = candidate_idx + self.found_ranks = found_ranks + self.found_idx = found_idx + self.errors = errors + self.has_undecided_case = False + + def reset_for_coalesced(self, entry_state: EntryState, candidate_ranks: set[int]) -> None: + self.entry_state = entry_state + self.candidate_ranks = candidate_ranks + self.candidate_idx = {} + self.found_ranks = set() + self.found_idx = {} + self.errors = set() diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 2702dfd79327abfa87d225ae451aff661a9c3516..b328157f3735390339cf80cdd6bbcf5b300a93d1 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -27,6 +27,7 @@ #include "third_party/acl/inc/acl/acl_base.h" #include "torch_npu/csrc/aten/CustomFunctions.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/core/npu/GetCANNInfo.h" #include "torch_npu/csrc/core/npu/NPUFunctions.h" #include 
"torch_npu/csrc/core/NPUBridge.h" #include "torch_npu/csrc/core/NPUStorageImpl.h" @@ -289,7 +290,16 @@ void getHcclCommConfig(HcclCommConfig* config, bool isP2P = false) } // Temporarily adding this logic to set deterministic states to avoid a known issues within HCCL. - config->hcclDeterministic = getDeterministicState() ? 1 : 0; + static const bool isCannVersionGteBase = []() { + const std::string baseCannversion = "8.2.RC1"; + const std::string baseCannModule = "CANN"; + return IsGteCANNVersion(baseCannversion, baseCannModule); + }(); + if (isCannVersionGteBase) { + config->hcclDeterministic = 0xffffffff; + } else { + config->hcclDeterministic = getDeterministicState() ? 1 : 0; + } // Compatible with the size check of the old version of HCCL, forcibly convert // the config object to a size_t=32 object, and retain the N ± 2 version