#if defined(USE_CUDA)

#include <test/cpp/jit/test_base.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>

// fuser and IR parser
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include "torch/csrc/jit/ir/irparser.h"

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <iostream>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser;

namespace {

// Builds an nDims-dimensional symbolic tensor whose every axis is marked
// contiguous. Extents are symbolic (unbound) Ints starting at 0.
TensorView* makeContigTensor(int nDims, DataType dtype = DataType::Float) {
  std::vector<IterDomain*> dom;
  for (int i = 0; i < nDims; i++)
    dom.push_back(new IterDomain(new Int(0), new Int()));
  std::vector<bool> contig(dom.size(), true);
  return new TensorView(new TensorDomain(dom, contig), dtype);
}

// Builds an nDims-dimensional symbolic tensor for tests. Currently forwards
// to makeContigTensor so that all tests run with contiguous tensors; the
// statements after the return are the non-contiguous variant, intentionally
// kept (unreachable) for easy switching.
TensorView* makeDummyTensor(int nDims, DataType dtype = DataType::Float) {
  return makeContigTensor(nDims, dtype);
  std::vector<IterDomain*> dom;
  for (int i = 0; i < nDims; i++)
    dom.push_back(new IterDomain(new Int(0), new Int()));
  return new TensorView(new TensorDomain(dom), dtype);
}

// Builds a tensor whose extents are concrete constants taken from `sizes`
// (non-contiguous domain).
TensorView* makeConcreteTensor(
    std::vector<int> sizes,
    DataType dtype = DataType::Float) {
  std::vector<IterDomain*> dom;
  for (size_t i = 0; i < sizes.size(); i++)
    dom.push_back(new IterDomain(new Int(0), new Int(sizes[i])));
  return new TensorView(new TensorDomain(dom), dtype);
}

// Builds an nDims-dimensional symbolic tensor with per-axis contiguity
// given explicitly by `contig_info`.
TensorView* makeTensorWithContig(
    int nDims,
    std::vector<bool> contig_info,
    DataType dtype = DataType::Float) {
  std::vector<IterDomain*> dom;
  for (int i = 0; i < nDims; i++)
    dom.push_back(new IterDomain(new Int(0), new Int()));
  return new TensorView(new TensorDomain(dom, contig_info), dtype);
}

// Evaluates `val` in `eval_context` and checks that it is an Int holding
// exactly `expected_value`.
void checkIntValue(
    const EvaluationContext* eval_context,
    Val* val,
    Int::ScalarType expected_value) {
  TORCH_CHECK(val->isAnInt());
  const auto actual_value = ExpressionEvaluator::evaluate(val, eval_context);
  TORCH_CHECK(actual_value.has_value());
  TORCH_CHECK(actual_value.value() == expected_value);
}

} // namespace

// 1. Test cases are void() functions.
// 2. They start with the prefix `test`

// A few smoke tests for IrGraphGenerator
// (These tests exercise IrGraphGenerator through a non-trivial IR,
//  to make sure that it runs w/o crashing.
The actual output is not // validated) void testGPU_IrGraphGenerator() { Fusion fusion; FusionGuard fg(&fusion); // Make sure we can handle empty IRs TORCH_CHECK(!IrGraphGenerator::toGraphviz( &fusion, IrGraphGenerator::DetailLevel::Basic) .empty()); // Construct an interesting IR TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv2 = add(tv0, new Float(3.141)); TensorView* tv3 = broadcast(tv0, {false, true, false, true}); TensorView* tv4 = reductionOp(BinaryOpType::Add, {2}, new Float(0), tv3); TensorView* tv5 = clamp(tv4, new Float(0.f), new Float(1.f)); TensorView* tv6 = add(tv2, tv2); // Another checkpoint before adding outputs TORCH_CHECK(!IrGraphGenerator::toGraphviz( &fusion, IrGraphGenerator::DetailLevel::Explicit) .empty()); fusion.addOutput(tv6); tv4->axis(2)->parallelize(ParallelType::BIDy); tv6->merge(0); tv6->split(0, 4); tv6->axis(0)->parallelize(ParallelType::BIDx); tv5->reorder({{-1, 0}}); tv2->computeAt(tv6, 1); // Another checkpoint with more node types TORCH_CHECK(!IrGraphGenerator::toGraphviz( &fusion, IrGraphGenerator::DetailLevel::ComputeOnly) .empty()); for (Val* val : fusion.vals()) { if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(val); tv->axis(-1)->parallelize(ParallelType::TIDx); } } // Final IR graph TORCH_CHECK(!IrGraphGenerator::toGraphviz( &fusion, IrGraphGenerator::DetailLevel::Verbose) .empty()); } void testGPU_FusionDispatch() { Fusion fusion; FusionGuard fg(&fusion); Float* f = new Float{2.f}; std::stringstream ss1, ss2, ss3; ss1 << f; ss2 << static_cast(f); ss3 << static_cast(f); TORCH_CHECK( ss1.str().compare(ss2.str()) == 0 && ss1.str().compare(ss3.str()) == 0, "Error with dispatch system where results differ by passing Float* vs Val* vs Statement*."); } // Evaluate basic scalar operations with constant values void testGPU_FusionExprEvalConstants() { Fusion fusion; FusionGuard fg(&fusion); EvaluationContext eval_context(&fusion); auto* a = new 
Int(7); auto* b = new Int(3); checkIntValue(&eval_context, neg(a), -7); checkIntValue(&eval_context, add(a, b), 10); checkIntValue(&eval_context, neg(mul(sub(a, b), div(a, b))), -8); checkIntValue(&eval_context, mod(a, b), 1); checkIntValue(&eval_context, ceilDiv(a, b), 3); } // Evaluate basic scalar operations with bound values void testGPU_FusionExprEvalBindings() { Fusion fusion; FusionGuard fg(&fusion); EvaluationContext eval_context(&fusion); auto* a = new Int(); auto* b = new Int(); auto* c = add(a, b); auto* d = neg(ceilDiv(c, b)); auto* e = new Int(0); // trying to evaluate before binding should give empty results TORCH_CHECK(!ExpressionEvaluator::evaluate(a, &eval_context).has_value()); TORCH_CHECK(!ExpressionEvaluator::evaluate(d, &eval_context).has_value()); eval_context.bind(a, 7); eval_context.bind(b, 3); // can't bind to the results of expressions ASSERT_ANY_THROW(eval_context.bind(c, 100)); // can't bind to concrete values ASSERT_ANY_THROW(eval_context.bind(e, 100)); checkIntValue(&eval_context, c, 10); checkIntValue(&eval_context, sub(a, b), 4); checkIntValue(&eval_context, mod(a, b), 1); checkIntValue(&eval_context, ceilDiv(a, b), 3); checkIntValue(&eval_context, d, -4); // Reset evaluation context eval_context = EvaluationContext(&fusion); eval_context.bind(a, 2); eval_context.bind(b, 5); checkIntValue(&eval_context, c, 7); checkIntValue(&eval_context, sub(a, b), -3); checkIntValue(&eval_context, mod(a, b), 2); checkIntValue(&eval_context, ceilDiv(a, b), 1); checkIntValue(&eval_context, d, -2); } // Evaluate expressions in a simple IR void testGPU_FusionExprEvalBasic() { Fusion fusion; FusionGuard fg(&fusion); // Create a non-trivial IR TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); tv3->split(0, 4); tv0->computeAt(tv3, 1); tv1->computeAt(tv3, 1); 
tv3->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(1)->parallelize(ParallelType::Unroll); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); // 1. Create an evaluation context EvaluationContext eval_context(&fusion); // 2. Bind values // // IMPORTANT: // a. The bindings are only as stable as the Vals are in the fusion graph // b. You must use the original (rootDomain) extents // (ex. `tv0->getRootDomain()[0]->extent()` // instead of `tv0->axis(0)->extent()`) // eval_context.bind(tv0->getRootDomain()[0]->extent(), 6); eval_context.bind(tv0->getRootDomain()[1]->extent(), 128); eval_context.bind(tv1->getRootDomain()[0]->extent(), 6); eval_context.bind(tv1->getRootDomain()[1]->extent(), 128); // 3. Evaluate and check result values TORCH_CHECK(tv2->domain()->nDims() == 3); checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2); checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4); checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128); TORCH_CHECK(tv3->domain()->nDims() == 3); checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2); checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4); checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128); } // Evaluate expressions in a more complex IR void testGPU_FusionExprEvalComplex() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv1 = mul(tv0, new Float(-1.0)); TensorView* tv2 = add(tv0, new Float(3.0)); TensorView* tv3 = mul(tv0, new Float(2.0)); TensorView* tv4 = add(tv2, tv1); TensorView* tv5 = add(tv4, tv3); TensorView* tv6 = add(tv0, tv3); fusion.addOutput(tv5); fusion.addOutput(tv6); tv5->reorder({{-1, 0}}); tv6->split(0, 5); tv5->merge(0); // 1. Create an evaluation context EvaluationContext eval_context(&fusion); // 2. 
Bind values eval_context.bind(tv0->getRootDomain()[0]->extent(), 129); eval_context.bind(tv0->getRootDomain()[1]->extent(), 127); // Evaluate and check extent values TORCH_CHECK(tv0->domain()->nDims() == 2); checkIntValue(&eval_context, tv0->axis(0)->rawExtent(), 129); checkIntValue(&eval_context, tv0->axis(1)->rawExtent(), 127); TORCH_CHECK(tv3->domain()->nDims() == 2); checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 129); checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 127); TORCH_CHECK(tv4->domain()->nDims() == 2); checkIntValue(&eval_context, tv4->axis(0)->rawExtent(), 129); checkIntValue(&eval_context, tv4->axis(1)->rawExtent(), 127); TORCH_CHECK(tv5->domain()->nDims() == 1); checkIntValue(&eval_context, tv5->axis(0)->rawExtent(), 16383); TORCH_CHECK(tv6->domain()->nDims() == 3); checkIntValue(&eval_context, tv6->axis(0)->rawExtent(), 26); checkIntValue(&eval_context, tv6->axis(1)->rawExtent(), 5); checkIntValue(&eval_context, tv6->axis(2)->rawExtent(), 127); } // Evaluate expressions post lowering void testGPU_FusionExprEvalPostLower() { Fusion fusion; FusionGuard fg(&fusion); // Create a non-trivial IR TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); tv3->split(0, 4); tv0->computeAt(tv3, 1); tv1->computeAt(tv3, 1); tv3->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(1)->parallelize(ParallelType::Unroll); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0)); auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0)); // Lower GpuLower gpulw(&fusion); std::stringstream kernel; gpulw.printKernel(kernel); // 1. Create an evaluation context EvaluationContext eval_context(&fusion); // 2. 
Bind values eval_context.bind(tv0->getRootDomain()[0]->extent(), 6); eval_context.bind(tv0->getRootDomain()[1]->extent(), 128); eval_context.bind(tv1->getRootDomain()[0]->extent(), 6); eval_context.bind(tv1->getRootDomain()[1]->extent(), 128); // 3. Evaluate and check result values TORCH_CHECK(tv2->domain()->nDims() == 3); checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2); checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4); checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128); TORCH_CHECK(tv3->domain()->nDims() == 3); checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2); checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4); checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128); checkIntValue(&eval_context, bid_x, 2); checkIntValue(&eval_context, tid_x, 128); } void testGPU_FusionClear() { Fusion fusion; FusionGuard fg(&fusion); // 1. Create a dummy IR { TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); tv3->split(0, 4); tv0->computeAt(tv3, 1); tv1->computeAt(tv3, 1); tv3->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); } // 2. Clear the IR fusion.clear(); TORCH_CHECK(fusion.exprs().empty()); TORCH_CHECK(fusion.vals().empty()); TORCH_CHECK(fusion.inputs().empty()); TORCH_CHECK(fusion.outputs().empty()); TORCH_CHECK(!fusion.hasReduction()); TORCH_CHECK(!fusion.hasBlockReduction()); TORCH_CHECK(!fusion.hasGridReduction()); // 3. 
Rebuild the IR { TensorView* tv0 = makeDummyTensor(3); TensorView* tv1 = makeDummyTensor(3); TensorView* tv2 = add(tv1, new Float(2.0)); TensorView* tv3 = add(tv0, tv2); fusion.addInput(tv0); fusion.addInput(tv1); fusion.addOutput(tv3); // tv3 [i0, i1, i2] tv3->reorder({{0, 2}, {2, 0}}); // tv3 [i2, i1, i0] tv3->split(-1, 4); // tv3 [i2, i1, i0outer, i0inner{4}] tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); // tv3 [i0outer, i0inner{4}, i1, i2] tv0->computeAt(tv3, -1); tv1->computeAt(tv3, -1); tv3->axis(1)->parallelize(ParallelType::BIDx); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16, 8, 8}, options); at::Tensor input2 = at::randn_like(input1); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input1, input2}); at::Tensor tv2_ref = input2 + 2.0; at::Tensor output_ref = input1 + tv2_ref; TORCH_CHECK(output_ref.equal(outputs[0])); } void testGPU_FusionCopy() { Fusion original_fusion; // Create the test IR { FusionGuard fg(&original_fusion); auto tv0 = makeDummyTensor(3); auto tv1 = makeDummyTensor(3); auto tv2 = add(tv1, new Float(2.0)); auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); original_fusion.addInput(tv0); original_fusion.addInput(tv1); original_fusion.addOutput(tv3); tv3->reorder({{0, 2}, {2, 0}}); tv3->split(-1, 4); tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); tv0->computeAt(tv3, -1); tv1->computeAt(tv3, -1); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); } // Test copy before lowering Fusion clone = original_fusion; // Compare IR dumps std::stringstream original_ir; std::stringstream clone_ir; original_ir << original_fusion; clone_ir << clone; ASSERT_EQ(original_ir.str(), clone_ir.str()); // Lower original fusion std::stringstream original_kernel; { GpuLower lower(&original_fusion); lower.printKernel(original_kernel); } // Make sure the "before lowering" clone was not mutated // while lowering the original 
fusion IR std::stringstream before_lowering_ir; before_lowering_ir << clone; ASSERT_EQ(original_ir.str(), before_lowering_ir.str()); // Test copy after lowering (including assignment operator) Fusion before_lowering = clone; clone = original_fusion; // Compare IR dumps std::stringstream original_lowered_ir; std::stringstream clone_lowered_ir; original_lowered_ir << original_fusion; clone_lowered_ir << clone; ASSERT_EQ(original_lowered_ir.str(), clone_lowered_ir.str()); // Lower the "before lowering" and compare kernels std::stringstream clone_kernel; { GpuLower lower(&before_lowering); lower.printKernel(clone_kernel); } ASSERT_EQ(original_kernel.str(), clone_kernel.str()); } void testGPU_FusionMove() { Fusion fusion; // Create the test IR { FusionGuard fg(&fusion); auto tv0 = makeDummyTensor(3); auto tv1 = makeDummyTensor(3); auto tv2 = add(tv1, new Float(2.0)); auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); fusion.addInput(tv0); fusion.addInput(tv1); fusion.addOutput(tv3); tv3->reorder({{0, 2}, {2, 0}}); tv3->split(-1, 4); tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); tv0->computeAt(tv3, -1); tv1->computeAt(tv3, -1); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); } std::stringstream original_ir; original_ir << fusion; // Test move before lowering Fusion another_fusion = std::move(fusion); // Check that the original fusion is "empty" // // IMPORTANT: these checks assume knowledge of the internal // implementation of the move operations. General uses // should only assume that the moved-from object is in // a valid, but unspecified state. 
This is similar to the // standard library containers: // https://en.cppreference.com/w/cpp/utility/move // TORCH_CHECK(fusion.exprs().empty()); TORCH_CHECK(fusion.vals().empty()); TORCH_CHECK(fusion.inputs().empty()); TORCH_CHECK(fusion.outputs().empty()); // clear() has no pre-conditions so it's valid to call on a moved-from object fusion.clear(); // Compare IR dumps std::stringstream another_ir; another_ir << another_fusion; ASSERT_EQ(original_ir.str(), another_ir.str()); // Lower the fusion IR std::stringstream kernel; GpuLower lower(&another_fusion); lower.printKernel(kernel); std::stringstream lowered_ir; lowered_ir << another_fusion; // Test move assignment after lowering fusion = std::move(another_fusion); // Compare IR dumps std::stringstream moved_lowered_ir; moved_lowered_ir << fusion; ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str()); } void testGPU_FusionSimpleArith() { std::stringstream ss1, ss2; Fusion fusion; FusionGuard fg(&fusion); Float* f1 = new Float(1.f); Float* f2 = new Float{2.f}; Float* f3 = new Float(); // Disrupt the fusion to make sure guard works well { Fusion fusion2; FusionGuard fg(&fusion2); Float* f1 = new Float(1.f); Float* f2 = new Float(2.f); add(f1, f2); ss2 << fusion2; } new BinaryOp(BinaryOpType::Add, f3, f1, f2); ss1 << fusion; TORCH_CHECK( ss1.str().compare(ss2.str()) == 0, "Error where explicit add nodes don't match implicit add nodes."); } void testGPU_FusionSimpleTypePromote() { Fusion fusion; FusionGuard fg(&fusion); Float* f4 = new Float{4.f}; Int* i1 = new Int{3}; auto f5 = add(f4, i1); TORCH_CHECK(f5->getDataType() == DataType::Float); } class ZeroMutator : public OptOutMutator { public: Statement* mutate(Float* f) { if (f->isConst() && *(f->value()) == 1.0) return new Float(0.0); return f; } void mutate(Fusion* f) { OptOutMutator::mutate(f); } }; void testGPU_FusionMutator() { Fusion fusion; FusionGuard fg(&fusion); Float* f4 = new Float{1.f}; Int* i1 = new Int{3}; Val* f5 = add(f4, i1); ZeroMutator mutator; 
mutator.mutate(&fusion); Val* lhs = static_cast(fusion.origin(f5))->lhs(); TORCH_CHECK( lhs->getValType().value() == ValType::Scalar && lhs->getDataType().value() == DataType::Float); Float* flhs = static_cast(lhs); TORCH_CHECK(flhs->value().value() == 0.f); } void testGPU_FusionRegister() { Fusion fusion; FusionGuard fg(&fusion); Float* v1 = new Float{1.f}; Float* v2 = new Float{2.f}; Val* v3 = binaryOp(BinaryOpType::Add, v1, v2); Val* v4 = binaryOp(BinaryOpType::Add, v1, v2); TORCH_CHECK(v1->name() + 1 == v2->name()); TORCH_CHECK(v2->name() + 1 == v3->name()); TORCH_CHECK(v3->name() + 1 == v4->name()); TORCH_CHECK(fusion.origin(v3)->name() + 1 == fusion.origin(v4)->name()); } // dummy expr with 2 outputs only for toposort test. struct DummyExpr : public Expr { ~DummyExpr() = default; DummyExpr(Val* _outlhs, Val* _outrhs, Val* _lhs, Val* _rhs) : Expr(ExprType::UnaryOp) // Not terribly safe... { addOutput(_outlhs); addOutput(_outrhs); addInput(_lhs); addInput(_rhs); this->name_ = FusionGuard::getCurFusion()->registerExpr(this); } DummyExpr(const DummyExpr& other) = delete; DummyExpr& operator=(const DummyExpr& other) = delete; DummyExpr(DummyExpr&& other) = delete; DummyExpr& operator=(DummyExpr&& other) = delete; }; void testGPU_FusionTopoSort() { Fusion fusion; FusionGuard fg(&fusion); // e0: v3, v2 = dummy(v1, v0) // e1: v4 = add(v3, v2) // e2: v5 = add(v2, v4) // e3: v6 = add(v5, v5) Float* v0 = new Float{1.f}; Float* v1 = new Float{2.f}; Float* v2 = new Float(); Float* v3 = new Float(); Float* v4 = new Float(); Float* v5 = new Float(); Float* v6 = new Float(); Expr* e0 = new DummyExpr(v3, v2, v1, v0); Expr* e1 = new BinaryOp(BinaryOpType::Add, v4, v3, v2); Expr* e2 = new BinaryOp(BinaryOpType::Add, v5, v2, v4); Expr* e3 = new BinaryOp(BinaryOpType::Add, v6, v5, v5); std::vector exprs = fusion.exprs(); TORCH_CHECK(exprs.size() == 4); TORCH_CHECK(exprs[0] == e0); TORCH_CHECK(exprs[1] == e1); TORCH_CHECK(exprs[2] == e2); TORCH_CHECK(exprs[3] == e3); 
fusion.addOutput(v2); exprs = fusion.exprs(true); TORCH_CHECK(exprs.size() == 1); TORCH_CHECK(exprs[0] == e0); fusion.addOutput(v5); exprs = fusion.exprs(true); TORCH_CHECK(exprs[0] == e0); TORCH_CHECK(exprs[1] == e1); TORCH_CHECK(exprs[2] == e2); fusion.addOutput(v4); exprs = fusion.exprs(true); TORCH_CHECK(exprs[0] == e0); TORCH_CHECK(exprs[1] == e1); TORCH_CHECK(exprs[2] == e2); fusion.addOutput(v3); exprs = fusion.exprs(true); TORCH_CHECK(exprs[0] == e0); TORCH_CHECK(exprs[1] == e1); TORCH_CHECK(exprs[2] == e2); fusion.addOutput(v6); exprs = fusion.exprs(true); TORCH_CHECK(exprs.size() == 4); TORCH_CHECK(exprs[0] == e0); TORCH_CHECK(exprs[1] == e1); TORCH_CHECK(exprs[2] == e2); TORCH_CHECK(exprs[3] == e3); TORCH_CHECK(fusion.origin(v2)->name() == 0); TORCH_CHECK(fusion.origin(v3)->name() == 0); TORCH_CHECK(fusion.origin(v4)->name() == 1); TORCH_CHECK(fusion.origin(v5)->name() == 2); TORCH_CHECK(fusion.origin(v6)->name() == 3); } void testGPU_FusionTensor() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); Fusion fusion; FusionGuard fg(&fusion); { auto tensor = at::randn({2, 3, 4, 5}, options); auto tensor_type = TensorType::create(tensor); auto fuser_tensor = new TensorView(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); for (int i = 0; i < static_cast(fuser_tensor->nDims()); i++) { // size 1 dimension are makred as broadcast TORCH_CHECK( fuser_tensor->axis(i)->isBroadcast() == (tensor.sizes()[i] == 1)); // check contiguity information; TORCH_CHECK(fuser_tensor->domain()->contiguity()[i]); } } { auto tensor = at::randn({2, 1, 4}, options); auto tensor_type = TensorType::create(tensor); auto fuser_tensor = new TensorView(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); 
TORCH_CHECK(fuser_tensor->domain() != nullptr); for (int i = 0; i < static_cast(fuser_tensor->nDims()); i++) { // size 1 dimension are makred as broadcast TORCH_CHECK( fuser_tensor->axis(i)->isBroadcast() == (tensor.sizes()[i] == 1)); } TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); // temporary WAR to disable contig & bcast; issue # 230 // TODO: insert the check where broadcast & contiguous cannot be marked // together TORCH_CHECK(!fuser_tensor->domain()->contiguity()[0]); TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); } { auto tensor = at::randn({2, 3, 1}, options); auto tensor_type = TensorType::create(tensor); auto fuser_tensor = new TensorView(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); for (int i = 0; i < static_cast(fuser_tensor->nDims()); i++) { // size 1 dimension are makred as broadcast TORCH_CHECK( fuser_tensor->axis(i)->isBroadcast() == (tensor.sizes()[i] == 1)); } TORCH_CHECK(fuser_tensor->domain()->contiguity()[0]); // temporary WAR to disable contig & bcast; issue # 230 // TODO: insert the check where broadcast & contiguous cannot be marked // together TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); TORCH_CHECK(!fuser_tensor->domain()->contiguity()[2]); } // TensorType::create fills stride_properties, which helps us to mark // IterDomain properly // Note: implementation could change, depending on how much we want to invest // in our home-brew contiguity coalescing. For now let's make sure that we // properly test what we are using. 
{ auto tensor = at::randn({4, 4, 4}, options); auto sliced_tensor = tensor.slice(1, 0, -1, 2); auto tensor_type = TensorType::create(sliced_tensor); auto fuser_tensor = new TensorView(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); for (int i = 0; i < static_cast(fuser_tensor->nDims()); i++) { // size 1 dimension are makred as broadcast TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); } TORCH_CHECK(fuser_tensor->domain()->contiguity()[0]); TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); } { auto tensor = at::randn({2, 3, 4, 5}, options); auto permuted_tensor = tensor.permute({0, 3, 1, 2}); auto tensor_type = TensorType::create(permuted_tensor); auto fuser_tensor = new TensorView(tensor_type); TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); TORCH_CHECK(fuser_tensor->domain() != nullptr); for (int i = 0; i < static_cast(fuser_tensor->nDims()); i++) { // size 1 dimension are makred as broadcast TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); } TORCH_CHECK(!fuser_tensor->domain()->contiguity()[0]); TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); TORCH_CHECK(!fuser_tensor->domain()->contiguity()[3]); } } void testGPU_FusionFilterVals() { Fusion fusion; FusionGuard fg(&fusion); auto tv0 = makeDummyTensor(1); auto tv1 = makeDummyTensor(1); auto scalar0 = new Float(0); auto scalar1 = new Int(0); auto scalar2 = new Int(1); const std::vector vals = {tv0, scalar0, tv1, scalar1, scalar2}; std::vector tvs( ir_utils::filterByType(vals).begin(), ir_utils::filterByType(vals).end()); TORCH_CHECK(tvs.size() == 2); TORCH_CHECK(tvs[0] == tv0); TORCH_CHECK(tvs[1] == tv1); std::vector floats( 
ir_utils::filterByType(vals).begin(), ir_utils::filterByType(vals).end()); TORCH_CHECK(floats.size() == 1); TORCH_CHECK(floats[0] == scalar0); std::vector ints( ir_utils::filterByType(vals).begin(), ir_utils::filterByType(vals).end()); TORCH_CHECK(ints.size() == 2); TORCH_CHECK(ints[0] == scalar1); TORCH_CHECK(ints[1] == scalar2); TORCH_CHECK( ir_utils::filterByType(vals).begin() == ir_utils::filterByType(vals).end(), "Not expecting any results"); } void testGPU_FusionTVSplit() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv = makeDummyTensor(3); tv = tv->split(2, 2); TORCH_CHECK(tv->nDims() == 4); Expr* outer = tv->axis(2)->extent()->getOrigin(); TORCH_CHECK( outer->getExprType().value() == ExprType::BinaryOp && static_cast(outer)->getBinaryOpType() == BinaryOpType::CeilDiv && static_cast(outer)->lhs()->sameAs( tv->getRootDomain()[2]->extent()) && static_cast(static_cast(outer)->rhs()) ->sameAs(new Int(2))); IterDomain* inner = static_cast(tv->axis(3)); TORCH_CHECK( inner->extent()->isScalar() && static_cast(inner->extent())->isConst() && static_cast(inner->extent())->value().value() == 2); } void testGPU_FusionTVMerge() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv = makeDummyTensor(3); tv = tv->merge(1); Expr* axisOp = tv->axis(1)->extent()->getOrigin(); TORCH_CHECK( tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp && static_cast(axisOp)->getBinaryOpType() == BinaryOpType::Mul && static_cast(axisOp)->lhs() == tv->getRootDomain()[1]->extent() && static_cast(axisOp)->rhs() == tv->getRootDomain()[2]->extent()); } void testGPU_FusionTVReorder() { Fusion fusion; FusionGuard fg(&fusion); std::unordered_map shift_right{{-1, 0}}; std::unordered_map shift_left{{0, -1}}; std::unordered_map shift_left_2{{0, -1}, {1, 0}, {2, 1}}; std::unordered_map swap{{0, 2}, {2, 0}}; auto tv = makeDummyTensor(3); std::vector ref; ref = std::vector( tv->domain()->domain().begin(), tv->domain()->domain().end()); tv->reorder(shift_left); for (int i = 0; i 
< (int)tv->nDims(); i++) TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); tv = makeDummyTensor(3); ref = std::vector( tv->domain()->domain().begin(), tv->domain()->domain().end()); tv->reorder(shift_left); for (int i = 0; i < (int)tv->nDims(); i++) TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); tv = makeDummyTensor(3); ref = std::vector( tv->domain()->domain().begin(), tv->domain()->domain().end()); tv->reorder(shift_right); TORCH_CHECK(ref[ref.size() - 1]->sameAs(tv->axis(0))); for (int i = 1; i < (int)tv->nDims(); i++) TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i))); tv = makeDummyTensor(3); ref = std::vector( tv->domain()->domain().begin(), tv->domain()->domain().end()); tv->reorder(swap); TORCH_CHECK(ref[0]->sameAs(tv->axis(2))); TORCH_CHECK(ref[2]->sameAs(tv->axis(0))); TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); } void testGPU_FusionEquality() { Fusion fusion; FusionGuard fg(&fusion); Float* fval1 = new Float(); Float* fval1_copy = fval1; Float* fval2 = new Float(); Float* fone = new Float(1.0); TORCH_CHECK(fval1->sameAs(fval1_copy)); TORCH_CHECK(!fval1->sameAs(fval2)); TORCH_CHECK(!fone->sameAs(fval1)); TORCH_CHECK(fone->sameAs(new Float(1.0))); Int* ival1 = new Int(); Int* ival1_copy = ival1; Int* ival2 = new Int(); Int* ione = new Int(1); TORCH_CHECK(ival1->sameAs(ival1_copy)); TORCH_CHECK(!ival1->sameAs(ival2)); TORCH_CHECK(!ione->sameAs(ival1)); TORCH_CHECK(ione->sameAs(new Int(1))); BinaryOp* add1 = new BinaryOp(BinaryOpType::Add, new Float(), fval1, ival1); BinaryOp* add1_copy = new BinaryOp(BinaryOpType::Add, new Float(), fval1, ival1); BinaryOp* sub1 = new BinaryOp(BinaryOpType::Sub, new Float(), fval1, ival1); UnaryOp* neg1 = new UnaryOp(UnaryOpType::Neg, new Float(), fval1); UnaryOp* neg2 = new UnaryOp(UnaryOpType::Neg, new Float(), fval2); UnaryOp* neg1_copy = new UnaryOp(UnaryOpType::Neg, new Float(), fval1); TORCH_CHECK(add1->sameAs(add1_copy)); TORCH_CHECK(!add1->sameAs(sub1)); TORCH_CHECK(neg1->sameAs(neg1_copy)); 
TORCH_CHECK(!static_cast(neg1)->sameAs(add1)); TORCH_CHECK(!neg1->sameAs(neg2)); } void testGPU_FusionDependency() { Fusion fusion; FusionGuard fg(&fusion); Float* f0 = new Float(0.f); Float* f1 = new Float(1.f); auto f2 = add(f0, f1); auto f3 = add(f2, f2); Float* f4 = new Float(4.f); Float* f5 = new Float(5.f); auto f6 = add(f4, f5); Float* f7 = new Float(7.f); Float* f8 = new Float(8.f); auto f9 = add(f7, f8); auto f10 = add(f6, f9); auto f11 = add(f3, f10); TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f11)); TORCH_CHECK(DependencyCheck::isDependencyOf(f1, f11)); TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f11)); TORCH_CHECK(DependencyCheck::isDependencyOf(f3, f11)); TORCH_CHECK(DependencyCheck::isDependencyOf(f6, f11)); TORCH_CHECK(DependencyCheck::isDependencyOf(f9, f11)); TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f2)); TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f3)); TORCH_CHECK(DependencyCheck::isDependencyOf(f4, f6)); TORCH_CHECK(DependencyCheck::isDependencyOf(f8, f10)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f0)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f1)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f2)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f3)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f4)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f5)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f2, f0)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f3, f2)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f6, f4)); TORCH_CHECK(!DependencyCheck::isDependencyOf(f10, f8)); auto dep_chain = DependencyCheck::getSingleDependencyChain(f0, f11); TORCH_CHECK(dep_chain.back() == f11); dep_chain.pop_back(); TORCH_CHECK(dep_chain.back() == f3); dep_chain.pop_back(); TORCH_CHECK(dep_chain.back() == f2); dep_chain.pop_back(); dep_chain = DependencyCheck::getSingleDependencyChain(f6, f11); TORCH_CHECK(dep_chain.back() == f11); dep_chain.pop_back(); TORCH_CHECK(dep_chain.back() == f10); 
dep_chain.pop_back(); dep_chain = DependencyCheck::getSingleDependencyChain(f4, f11); TORCH_CHECK(dep_chain.back() == f11); dep_chain.pop_back(); TORCH_CHECK(dep_chain.back() == f10); dep_chain.pop_back(); TORCH_CHECK(dep_chain.back() == f6); dep_chain.pop_back(); dep_chain = DependencyCheck::getSingleDependencyChain(f11, f2); TORCH_CHECK(dep_chain.empty()); } void testGPU_FusionParser() { auto g = std::make_shared(); const auto graph0_string = R"IR( graph(%0 : Float(2:1), %1 : Float(2:1)): %c0 : Float(2:1) = aten::mul(%0, %1) %d0 : Float(2:1) = aten::mul(%c0, %0) return (%d0))IR"; torch::jit::parseIR(graph0_string, g.get()); // strides are not yet supported in the irparser. for (auto val : g->block()->inputs()) { if (val->isCompleteTensor()) val->setType(val->type()->cast()->contiguous()); } for (auto node : g->block()->nodes()) { for (auto val : node->outputs()) { if (val->isCompleteTensor()) val->setType(val->type()->cast()->contiguous()); } } auto fusion = fuser::cuda::parseJitIR(g); FusionGuard fg(fusion.get()); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16}, options); at::Tensor input2 = at::randn({16}, options); fuser::cuda::scheduleFusion(fusion.get(), {input1, input2}); // CONSIDER: // 1. this can be moved to a dedicated "golden" file // 2. 
// (continuation of the CONSIDER list above)
// use a fuzzy compare (ignore non-significant whitespaces for example)
// NOTE(review): this golden-kernel literal appears to have lost its original
// internal newlines in this copy; the exact-match comparison below depends on
// byte-identical formatting — restore from upstream before trusting the test.
const std::string expected_kernel = R"( __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){ float T2[4]; if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T0.size[0] ) ) { for(size_t i6 = 0; i6 < 4; ++i6 ) { T2[ i6 ] = T0[ ( ( ( ( blockIdx.x * 4 ) + i6 ) * 128 ) + threadIdx.x ) ] * T1[ ( ( ( ( blockIdx.x * 4 ) + i6 ) * 128 ) + threadIdx.x ) ]; } } else { for(size_t i6 = 0; i6 < 4; ++i6 ) { if ( ( ( ( ( ( blockIdx.x * 4 ) + i6 ) * 128 ) + threadIdx.x ) < T0.size[0] ) ) { T2[ i6 ] = T0[ ( ( ( ( blockIdx.x * 4 ) + i6 ) * 128 ) + threadIdx.x ) ] * T1[ ( ( ( ( blockIdx.x * 4 ) + i6 ) * 128 ) + threadIdx.x ) ]; } } } if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T0.size[0] ) ) { for(size_t i13 = 0; i13 < 4; ++i13 ) { T3[ ( ( ( ( blockIdx.x * 4 ) + i13 ) * 128 ) + threadIdx.x ) ] = T2[ i13 ] * T0[ ( ( ( ( blockIdx.x * 4 ) + i13 ) * 128 ) + threadIdx.x ) ]; } } else { for(size_t i13 = 0; i13 < 4; ++i13 ) { if ( ( ( ( ( ( blockIdx.x * 4 ) + i13 ) * 128 ) + threadIdx.x ) < T0.size[0] ) ) { T3[ ( ( ( ( blockIdx.x * 4 ) + i13 ) * 128 ) + threadIdx.x ) ] = T2[ i13 ] * T0[ ( ( ( ( blockIdx.x * 4 ) + i13 ) * 128 ) + threadIdx.x ) ]; } } } } )";

// Exact (size + content) comparison against the golden kernel; prints both
// sides before failing so a mismatch is diagnosable.
std::string actual_kernel = GpuLower(fusion.get()).getKernel();
actual_kernel = "\n" + actual_kernel;
if (expected_kernel.size() != actual_kernel.size() ||
    expected_kernel.compare(actual_kernel) != 0) {
  std::cerr << " Codegen mismatch, codegen possibly changed, or is incorrect. "
            << " \n ========= EXPECTED ========= \n" << expected_kernel
            << "\n========= ACTUAL ========== \n" << actual_kernel
            << "\n=================" << std::endl;
  TORCH_CHECK(false);
}

// Execute the fusion and check numerics: d0 = (in1 * in2) * in1.
cuda::FusionExecutor fe;
fe.compileFusion(fusion.get());
auto outputs = fe.runFusion({input1, input2});
at::Tensor output_ref = input1 * input2 * input1;
TORCH_CHECK(output_ref.equal(outputs[0]));
}

// Checks the textual printing of a kernel-IR ForLoop node. Entire body is
// compiled out below.
void testGPU_FusionForLoop() {
// TODO(kir): re-enable this test
// due to the current "GpuLower guard" approach, we can only create
// kernel IR during GpuLower::lower()
#if 0
  Fusion fusion;
  FusionGuard fg(&fusion);
  const auto TV0 = new TensorView(
      new TensorDomain({new IterDomain(new Int(0), new Int(16))}),
      DataType::Float);
  const auto TV1 = new TensorView(
      new TensorDomain({new IterDomain(new Int(0), new Int(16))}),
      DataType::Float);
  fusion.addInput(TV0);
  fusion.addInput(TV1);
  auto ID0 = new kir::IterDomain(new IterDomain(new Int(0), new Int(8)));
  TensorView* TV2 = add(TV0, TV1);
  BinaryOp* op = static_cast(TV2->getOrigin());
  fusion.addOutput(TV2);
  auto fl = new kir::ForLoop(new kir::Int(c10::nullopt), ID0, {op});
  std::stringstream result;
  std::stringstream ref;
  result << fl;
  ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}";
  // NOTE(review): this condition looks inverted — compare(...) == 0 means the
  // strings MATCH, yet it triggers the failure path. Dead under #if 0, but
  // fix (to != 0) if this test is ever re-enabled.
  if (result.str().compare(ref.str()) == 0) {
    std::stringstream err_msg;
    err_msg << "ForLoop printing has changed or something has gone wrong. "
            << result.str() << "\n does not match reference: " << ref.str()
            << std::endl;
    TORCH_CHECK(false, err_msg.str());
  }
#endif
}

// Codegen smoke test with no inputs: chained scalar adds on a dummy tensor,
// with split/merge/reorder scheduling; output must equal 0+1+2+3 everywhere.
void testGPU_FusionCodeGen() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* tv0 = makeDummyTensor(3);
  new BinaryOp(BinaryOpType::Add, tv0, new Float(0.0), new Float(1.0));
  TensorView* tv1 = add(tv0, new Float(2.0));
  TensorView* tv2 = add(tv1, new Float(3.0));
  fusion.addOutput(tv2);
  //[I0, I1, I2]
  tv2 = tv2->split(0, 4);
  //[I0o, I0i{4}, I1, I2]
  tv2 = tv2->merge(1);
  //[I0o, I0i{4}*I1, I2]
  tv2 = tv2->split(-1, 2);
  //[I0o, I0i{4}*I1, I2o, I2i{2}]
  tv2 = tv2->reorder({{0, 1}, {1, 0}, {3, 2}});
  //[I0i{4}*I1, I0o, I2i{2}, I2o]
  tv0->computeAt(tv2, -1);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor output = at::empty({16, 8, 8}, options);
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  fe.runFusion({}, {output});
  at::Tensor output_ref = at::zeros_like(output, options);
  output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0;
  TORCH_CHECK(output_ref.equal(output));
}

// Two-input pointwise codegen test: tv3 = tv0 + (tv1 + 2.0) with reorder +
// split scheduling and BIDx/TIDx parallelization.
void testGPU_FusionCodeGen2() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* tv0 = makeDummyTensor(3);
  TensorView* tv1 = makeDummyTensor(3);
  TensorView* tv2 = add(tv1, new Float(2.0));
  TensorView* tv3 = add(tv0, tv2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv3);
  //[I0, I1, I2]
  tv3->reorder({{0, 2}, {2, 0}});
  //[I2, I1, I0]
  tv3->split(-1, 4);
  //[I2, I1, I0o, I0i{4}]
  tv3->reorder({{2, 0}, {3, 1}, {0, 3}});
  // I0o, I0i{4}, I1, I2]
  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input1 = at::randn({16, 8, 8}, options);
  at::Tensor input2 = at::randn_like(input1);
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input1, input2});
  at::Tensor tv2_ref = input2 + 2.0;
  at::Tensor
output_ref = input1 + tv2_ref;
  TORCH_CHECK(output_ref.equal(outputs[0]));
}

// Canonical pointwise example on contiguous tensors: merge all dims, split by
// thread count and unroll factor, inline-computeAt inputs, and compare
// against eager ATen.
void testGPU_FusionSimplePWise() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  // dimensionality of the problem
  int nDims = 3;
  // Set up your input tensor views
  TensorView* tv0 = makeContigTensor(nDims);
  TensorView* tv1 = makeContigTensor(nDims);
  // Register your inputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  TensorView* tv2 = add(tv1, new Float(2.0));
  TensorView* tv3 = add(tv0, tv2);
  // Register your outputs
  fusion.addOutput(tv3);
  // Do transformations, remember, transformations are outputs to inputs
  // This doesn't have to be in this order
  tv3->merge(1);
  tv3->merge(0);
  // Split by n_threads
  tv3->split(0, 128);
  tv3->split(0, 4);
  // For all inputs, computeAt the output inline, temporaries should be
  // squeezed between them
  tv0->computeAt(tv3, -1);
  tv1->computeAt(tv3, -1);
  // Parallelize TV3
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv3->axis(-2)->parallelize(ParallelType::Unroll);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input1 = at::randn({64, 2, 128}, options);
  at::Tensor input2 = at::rand_like(input1);
  at::Tensor output = at::empty_like(input1);
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  fe.runFusion({input1, input2}, {output});
  at::Tensor tv2_ref = input2 + 2.0;
  at::Tensor output_ref = input1 + tv2_ref;
  TORCH_CHECK(output_ref.equal(output));
}

// End-to-end executor test on ones-filled inputs; every element of
// out = in0 + (in1 + 2) must equal exactly 4.
void testGPU_FusionExecKernel() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);
  // Register your inputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  TensorView* tv2 = add(tv1, new Float(2.0));
  TensorView* tv3 = add(tv0, tv2);
  // Register your outputs
  fusion.addOutput(tv3);
  tv3->merge(0);
  tv3->split(0, 128);
  tv3->split(0, 4);
  // For all inputs, computeAt the output inline, temporaries should be
  // squeezed between them
  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);
  // Parallelize TV3
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input1 = at::ones({1, 128}, options);
  at::Tensor input2 = at::ones_like(input1);
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input1, input2});
  at::Tensor check = at::full({1, 128}, 4, options);
  ; // NOTE(review): stray empty statement; harmless, kept verbatim.
  TORCH_CHECK(outputs[0].equal(check));
}

// Ceiling division helper used by tests below.
int ceilDiv_(int a, int b) {
  return (a + b - 1) / b;
}

// Exercises computeAt propagation through multi-consumer DAGs (cases 1-5),
// checking both the resulting computeAt structure and the numerics.
void testGPU_FusionAdvancedComputeAt() {
  // Case 1
  // tv1 = tv0 * 0.5
  // tv2 = tv1 * -1
  // tv3 = tv1 + 3
  // tv4 = tv1 * 2
  // tv5 = tv3 + tv2
  // tv6 = tv5 + tv4
  // tv7 = tv1 + tv4
  {
    Fusion fusion;
    FusionGuard fg(&fusion);
    TensorView* tv0 = makeDummyTensor(2);
    fusion.addInput(tv0);
    TensorView* tv1 = mul(tv0, new Float(0.5));
    TensorView* tv2 = mul(tv1, new Float(-1.0));
    TensorView* tv3 = add(tv1, new Float(3.0));
    TensorView* tv4 = mul(tv1, new Float(2.0));
    TensorView* tv5 = add(tv3, tv2);
    TensorView* tv6 = add(tv5, tv4);
    TensorView* tv7 = add(tv1, tv4);
    fusion.addOutput(tv6);
    fusion.addOutput(tv7);
    // Lets setup to actually run
    tv7->merge(0);
    tv7->split(0, 128);
    tv7->split(0, 4);
    tv7->axis(0)->parallelize(ParallelType::BIDx);
    tv0->computeAt(tv7, 1);
    // Verify where each intermediate ended up being computed.
    TORCH_CHECK(tv1->hasComputeAt() && tv1->nDims() == 3);
    TORCH_CHECK(tv2->getComputeAtView() == tv5 && tv2->nDims() == 3);
    TORCH_CHECK(tv3->getComputeAtView() == tv5 && tv3->nDims() == 3);
    TORCH_CHECK(tv4->hasComputeAt() && tv4->nDims() == 3);
    TORCH_CHECK(tv5->getComputeAtView() == tv6 && tv5->nDims() == 3);
TORCH_CHECK(tv6->getComputeAtView() == tv7 && tv6->nDims() == 3);
    TORCH_CHECK(!tv7->hasComputeAt());
    // Parallelize every non-input TensorView the same way.
    for (Val* val : fusion.vals()) {
      if (!fusion.hasInput(val) &&
          val->getValType().value() == ValType::TensorView) {
        TensorView* tv = static_cast(val);
        tv->axis(1)->parallelize(ParallelType::Unroll);
        tv->axis(-1)->parallelize(ParallelType::TIDx);
      }
    }
    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor t0 = at::randn({129, 127}, options);
    // Eager-mode reference computation mirroring the fusion DAG above.
    auto t1 = t0.mul({0.5});
    auto t2 = t1.mul({-1.0});
    auto t3 = t1.add({3.0});
    auto t4 = t1.mul({2.0});
    auto t5 = t3.add(t2);
    auto t6 = t5.add(t4);
    auto t7 = t1.add(t4);
    at::Tensor kernel_tv6 = at::empty_like(t0, options);
    at::Tensor kernel_tv7 = at::empty_like(t0, options);
    torch::jit::fuser::cuda::FusionExecutor fe;
    fe.compileFusion(&fusion);
    fe.runFusion({t0}, {kernel_tv6, kernel_tv7});
    TORCH_CHECK(at::allclose(kernel_tv6, t6));
    TORCH_CHECK(at::allclose(kernel_tv7, t7));
  }

  // Case 2
  // tv1 = tv0 * -1
  // tv2 = tv0 + 3
  // tv3 = tv0 * 2
  // tv4 = tv2 + tv1
  // tv5 = tv4 + tv3
  // tv6 = tv5 + tv3
  {
    Fusion fusion;
    FusionGuard fg(&fusion);
    TensorView* tv0 = makeDummyTensor(2);
    fusion.addInput(tv0);
    TensorView* tv1 = mul(tv0, new Float(-1.0));
    TensorView* tv2 = add(tv0, new Float(3.0));
    TensorView* tv3 = mul(tv0, new Float(2.0));
    TensorView* tv4 = add(tv2, tv1);
    TensorView* tv5 = add(tv4, tv3);
    TensorView* tv6 = add(tv5, tv3);
    fusion.addOutput(tv5);
    fusion.addOutput(tv6);
    // Lets setup to actually run
    tv6->merge(0);
    tv6->split(0, 128);
    tv6->split(0, 4);
    tv6->axis(0)->parallelize(ParallelType::BIDx);
    tv0->computeAt(tv6, 1);
    for (Val* val : fusion.vals()) {
      if (!fusion.hasInput(val) &&
          val->getValType().value() == ValType::TensorView) {
        TensorView* tv = static_cast(val);
        tv->axis(1)->parallelize(ParallelType::Unroll);
        tv->axis(-1)->parallelize(ParallelType::TIDx);
      }
    }
    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor t0 = at::randn({129, 127}, options);
    auto t1 = t0.mul({-1.0});
    auto t2 = t0.add({3.0});
    auto t3 = t0.mul({2.0});
    auto t4 = t2.add(t1);
    auto t5 = t4.add(t3);
    auto t6 = t5.add(t3);
    torch::jit::fuser::cuda::FusionExecutor fe;
    fe.compileFusion(&fusion);
    auto outputs = fe.runFusion({t0});
    // Attach the lowered kernel text to the failure message for debugging.
    GpuLower gpulw(&fusion);
    std::stringstream actual_kernel;
    gpulw.printKernel(actual_kernel);
    TORCH_CHECK(at::allclose(outputs[0], t5), actual_kernel.str());
    TORCH_CHECK(at::allclose(outputs[1], t6));
  }

  // Case 3
  // T2 = T1 * 0.979361
  // T3 = T2 * T0
  {
    Fusion fusion;
    FusionGuard fg(&fusion);
    TensorView* tv0 = makeDummyTensor(4);
    fusion.addInput(tv0);
    TensorView* tv1 = makeDummyTensor(4);
    fusion.addInput(tv1);
    TensorView* tv2 = mul(tv1, new Float(.979361));
    TensorView* tv3 = mul(tv2, tv0);
    fusion.addOutput(tv3);
    // Lets setup to actually run
    while (tv3->nDims() > 1)
      tv3->merge(0);
    tv3->split(0, 128);
    tv3->split(0, 4);
    tv0->computeAt(tv3, 1);
    tv1->computeAt(tv3, 1);
    tv3->axis(0)->parallelize(ParallelType::BIDx);
    for (Val* val : fusion.vals()) {
      if (!fusion.hasInput(val) &&
          val->getValType().value() == ValType::TensorView) {
        TensorView* tv = static_cast(val);
        tv->axis(1)->parallelize(ParallelType::Unroll);
        tv->axis(-1)->parallelize(ParallelType::TIDx);
      }
    }
    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
    at::Tensor t1 = at::rand_like(t0, options);
    auto t2 = t1.mul({0.979361});
    auto t3 = t2.mul(t0);
    at::Tensor kernel_tv3 = at::empty_like(t0, options);
    torch::jit::fuser::cuda::FusionExecutor fe;
    fe.compileFusion(&fusion);
    fe.runFusion({t0, t1}, {kernel_tv3});
    GpuLower gpulw(&fusion);
    std::stringstream actual_kernel;
    gpulw.printKernel(actual_kernel);
    TORCH_CHECK(at::allclose(kernel_tv3, t3), actual_kernel.str());
  }

  // Case 4
  // T4 = T2 - T3
  // T5 = T1 + T4
  // T6 = T5 - T0
  {
    Fusion fusion;
    FusionGuard fg(&fusion);
    TensorView* tv0 = makeDummyTensor(4);
    fusion.addInput(tv0);
    TensorView* tv1 = makeDummyTensor(4);
    fusion.addInput(tv1);
    TensorView* tv2 = makeDummyTensor(4);
    fusion.addInput(tv2);
    TensorView* tv3 = makeDummyTensor(4);
    fusion.addInput(tv3);
    TensorView* tv4 = sub(tv2, tv3);
    TensorView* tv5 = add(tv1, tv4);
    TensorView* tv6 = sub(tv5, tv0);
    fusion.addOutput(tv6);
    // Lets setup to actually run
    while (tv6->nDims() > 1)
      tv6->merge(0);
    tv6->split(0, 128);
    tv6->split(0, 4);
    tv0->computeAt(tv6, 1);
    tv1->computeAt(tv6, 1);
    tv2->computeAt(tv6, 1);
    tv3->computeAt(tv6, 1);
    tv6->axis(0)->parallelize(ParallelType::BIDx);
    for (Val* val : fusion.vals()) {
      if (!fusion.hasInput(val) &&
          val->getValType().value() == ValType::TensorView) {
        TensorView* tv = static_cast(val);
        tv->axis(1)->parallelize(ParallelType::Unroll);
        tv->axis(-1)->parallelize(ParallelType::TIDx);
      }
    }
    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor t0 = at::randn({129, 127, 63, 65}, options);
    at::Tensor t1 = at::rand_like(t0, options);
    at::Tensor t2 = at::rand_like(t0, options);
    at::Tensor t3 = at::rand_like(t0, options);
    auto t4 = t2.sub(t3);
    auto t5 = t1.add(t4);
    auto t6 = t5.sub(t0);
    torch::jit::fuser::cuda::FusionExecutor fe;
    fe.compileFusion(&fusion);
    auto outputs = fe.runFusion({t0, t1, t2, t3});
    GpuLower gpulw(&fusion);
    std::stringstream actual_kernel;
    gpulw.printKernel(actual_kernel);
    TORCH_CHECK(at::allclose(outputs[0], t6), actual_kernel.str());
  }

  // Case 5
  // tv2 = tv0 + 2.0
  // tv3 = tv1 * tv2
  {
    Fusion fusion;
    FusionGuard fg(&fusion);
    // Set up your input tensor views
    TensorView* tv0 = makeDummyTensor(2);
    fusion.addInput(tv0);
    TensorView* tv1 = makeDummyTensor(2);
    fusion.addInput(tv1);
    TensorView* tv2 = add(tv0, new Float(2.0));
    TensorView* tv3 = mul(tv1, tv2);
    fusion.addOutput(tv3);
    tv3->merge(0);
    tv3->split(-1, 8);
    tv3->split(-1, 4);
    tv2->computeAt(tv3, 1);
    tv2->split(-1, 4); // Kernel will break without this split
    tv3->axis(0)->parallelize(ParallelType::BIDx);
    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor t0 = at::randn({63, 65}, options);
    at::Tensor t1 = at::rand_like(t0, options);
    auto t2 = t0.add(2.0);
    auto t3 = t1.mul(t2);
    torch::jit::fuser::cuda::FusionExecutor fe;
    fe.compileFusion(&fusion);
    auto outputs = fe.runFusion({t0, t1});
    TORCH_CHECK(at::allclose(outputs[0], t3));
  }
}

// Mixes tensor inputs with four runtime Float scalar inputs and verifies the
// fused result against the eagerly computed reference.
void testGPU_FusionScalarInputs() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = makeDummyTensor(2);
  fusion.addInput(tv1);
  Float* f0 = new Float();
  fusion.addInput(f0);
  Float* f1 = new Float();
  fusion.addInput(f1);
  Float* f2 = new Float();
  fusion.addInput(f2);
  Float* f3 = new Float();
  fusion.addInput(f3);
  Val* f4 = mul(f0, f1);
  Val* f5 = sub(f2, f3);
  TensorView* tv2 = sub(tv1, f4);
  TensorView* tv3 = add(tv0, f5);
  TensorView* tv4 = mul(tv3, tv2);
  fusion.addOutput(tv4);
  // Lets setup to actually run
  while (tv4->nDims() > 1)
    tv4->merge(0);
  tv4->split(0, 128);
  tv4->split(0, 4);
  tv0->computeAt(tv4, 1);
  tv1->computeAt(tv4, 1);
  tv4->axis(0)->parallelize(ParallelType::BIDx);
  for (Val* val : fusion.vals()) {
    if (!fusion.hasInput(val) &&
        val->getValType().value() == ValType::TensorView) {
      TensorView* tv = static_cast(val);
      tv->axis(1)->parallelize(ParallelType::Unroll);
      tv->axis(-1)->parallelize(ParallelType::TIDx);
    }
  }
  // f4 = f0 * f1
  // f5 = f2 - f3
  // t2 = t1 - f4
  // t3 = t0 + f5
  // t4 = t3 * t2
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  float fl0 = 0.1;
  float fl1 = -0.2;
  float fl2 = 0.3;
  float fl3 = -0.4;
  float fl4 = fl0 * fl1;
  float fl5 = fl2 - fl3;
  at::Tensor t0 = at::randn({129, 127}, options);
  at::Tensor t1 = at::rand_like(t0, options);
  auto t2 = t1.sub(fl4);
  auto t3 = t0.add(fl5);
  auto t4 = t3.mul(t2);
  at::Tensor kernel_tv4 = at::empty_like(t0, options);
  at::Scalar test(fl0); // NOTE(review): unused local, kept verbatim.
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  fe.runFusion(
      {t0,
       t1,
       at::Scalar(fl0),
       at::Scalar(fl1),
       at::Scalar(fl2),
       at::Scalar(fl3)},
      {kernel_tv4});
  GpuLower gpulw(&fusion);
  std::stringstream actual_kernel;
  gpulw.printKernel(actual_kernel);
TORCH_CHECK(at::allclose(kernel_tv4, t4), actual_kernel.str());
}

// Exercises explicit unrolling: 3D inputs merged twice, split by block_size
// and an unroll factor of 4, then compared against eager ATen.
void testGPU_FusionLoopUnroll() {
  Fusion fusion;
  FusionGuard fg(&fusion);
  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(3);
  TensorView* tv1 = makeDummyTensor(3);
  // Register your inputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  TensorView* tv2 = add(tv1, new Float(2.0));
  TensorView* tv3 = add(tv0, tv2);
  // Register your outputs
  fusion.addOutput(tv3);
  int block_size = 16;
  tv3->merge(0, 1);
  tv3->merge(0, 1);
  tv3->split(0, block_size);
  tv3->split(0, 4);
  // For all inputs, computeAt the output inline, temporaries should be
  // squeezed between them
  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);
  // Parallelize
  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(0)->parallelize(ParallelType::BIDx);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::rand({129, 13, 3}, options);
  at::Tensor input1 = at::rand({129, 13, 3}, options);
  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input0, input1});
  TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0))));
}

/*
 * Helper function for single op testing that generates a codegen operand
 */
Val* gen_jit_operand(std::pair desc) {
  if (desc.first == ValType::TensorView) {
    return makeDummyTensor(2, desc.second);
  } else if (desc.first == ValType::Scalar) {
    if (desc.second == DataType::Float)
      return new Float();
    else if (desc.second == DataType::Int)
      return new Int();
    else
      // NOTE(review): TORCH_CHECK with a non-empty string literal as the
      // condition is always true — this never fires. Likely meant
      // TORCH_CHECK(false, "Not currently supported type", ...).
      TORCH_CHECK("Not currently supported type", desc.first);
  } else {
    TORCH_CHECK("Not currently supported type", desc.first);
  }
  return nullptr;
}

/*
 * Helper function for single op testing that generates an ATen operand
 */
IValue gen_aten_operand(
    std::pair desc,
    int blocks,
    int threads,
    bool rand) {
  if (desc.first == ValType::TensorView) {
    if (desc.second == DataType::Float) {
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      if (rand)
        return IValue(at::rand({blocks, threads}, options));
      else
        return IValue(at::empty({blocks, threads}, options));
    } else if (desc.second == DataType::Half) {
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
      if (rand)
        return IValue(at::rand({blocks, threads}, options));
      else
        return IValue(at::empty({blocks, threads}, options));
    } else if (desc.second == DataType::Bool) {
      if (rand) {
        // Random bools are generated as floats then cast.
        auto options =
            at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
        return IValue(at::rand({blocks, threads}, options).to(at::kBool));
      } else {
        auto options =
            at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0);
        return IValue(at::empty({blocks, threads}, options));
      }
    } else {
      // NOTE(review): same always-true TORCH_CHECK pattern as above.
      TORCH_CHECK("Not currently supported type", desc.second)
    }
  } else if (desc.first == ValType::Scalar) {
    if (desc.second == DataType::Float)
      return IValue(at::Scalar(1.f));
    else if (desc.second == DataType::Int)
      return IValue(at::Scalar(1));
    else
      TORCH_CHECK("Not currently supported type", desc.first);
  } else {
    TORCH_CHECK("Not currently supported type", desc.first);
  }
  return nullptr;
}

/*
 * Templatized Helper Function To generate single Op comparison between the
 * JIT codegen for Cuda and the ATen Library.
 */
using OutputPair = std::pair;
template <
    typename AtenFunc,
    typename JitFunc,
    typename InputTuple,
    size_t...
NumInputs>
// Builds a fusion from the JIT-side op `jf`, runs it on generated operands,
// runs the ATen reference `af` on matching operands, and compares the two
// (exact equality for bool outputs, allclose otherwise).
void test_op(
    int blocks,
    int threads,
    std::string op_str,
    AtenFunc af,
    JitFunc jf,
    OutputPair op,
    InputTuple it,
    std::index_sequence) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Generate Input JIT function Inputs and add them as Inputs to the Fusion
  // Graph
  std::array jit_inputs = {gen_jit_operand(std::get(it))...};
  std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) {
    fusion.addInput(v);
  });
  TensorView* out = static_cast(jf(std::get(jit_inputs)...));
  fusion.addOutput(out);

  // Inline all tensor inputs into the output and parallelize it.
  std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) {
    if (v->getValType() == ValType::TensorView)
      static_cast(v)->computeAt(out, -1);
  });
  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(-1)->parallelize(ParallelType::TIDx);

  std::array aten_inputs = {gen_aten_operand(
      std::get(it), blocks, threads, /*rand*/ true)...};
  const at::ArrayRef aten_inputs_ivalues(aten_inputs);

  at::Tensor output =
      gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor();
  std::vector output_vect = {output};
  cudaDeviceSynchronize();
  // Reset the RNG so fused and reference RNG ops see identical seeds.
  if (fusion.hasRNG())
    at::manual_seed(0);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  fe.runFusion(aten_inputs_ivalues, output_vect);
  cudaDeviceSynchronize();

  if (fusion.hasRNG())
    at::manual_seed(0);
  at::Tensor ref_output = af(aten_inputs);
  cudaDeviceSynchronize(); // This sync shouldn't be necessary;

  // Lazily formats all inputs for the failure message below.
  std::function aten_inputs_to_str = [&aten_inputs]() -> std::string {
    int input_cnt = 1;
    std::stringstream ss;
    std::for_each(
        aten_inputs.begin(), aten_inputs.end(), [&input_cnt, &ss](IValue& iv) {
          ss << "\nINPUT" << input_cnt++ << ": " << iv.toTensor();
        });
    return ss.str();
  };

  at::Tensor diff;
  if (output.scalar_type() == at::kBool) {
    diff = at::eq(output, ref_output);
  } else {
    diff = at::sub(output, ref_output);
  }

  TORCH_CHECK(
      (output.scalar_type() == at::kBool
           ? output.equal(ref_output)
           :
           // The absolute Tolerance was raised to 1e-07 from 1e-08 to allow
           // allow for the remainder function to pass.
           output.allclose(ref_output, /*rtol*/ 1e-05, /*atol*/ 1e-07)),
      "\nOp Type: -- ",
      op_str,
      " -- had a mismatch.",
      aten_inputs_to_str(),
      "\nJIT: ",
      output,
      "\nREF: ",
      ref_output,
      "\nDIFF: ",
      diff,
      "\n");
}

/*
 * Templatized Helper Function that uses variadic templates to
 * process a variable length Input Tuple of different Operand Type.
 */
template
// Convenience overload: expands the input tuple into an index sequence and
// forwards to the indexed test_op above.
void test_op(
    int blocks,
    int threads,
    std::string op_str,
    AtenFunc af,
    JitFunc jf,
    OutputPair op,
    InputTuple it) {
  static constexpr auto size = std::tuple_size::value;
  test_op(
      blocks, threads, op_str, af, jf, op, it, std::make_index_sequence{});
}

// Runs every supported unary op through test_op, comparing fused output
// against the matching ATen function.
void testGPU_FusionUnaryOps() {
  using OpTuple = std::tuple;

  // [Note: explicit tuple type for uniform initialization list]
  // Tuple type must be explicitly specified for each uniform initialization
  // list within the vector to make this code compatible with some old env
  // which we still need to support. eg. gcc 5.4 + cuda 9.2.
  std::vector ops{
      OpTuple{at::abs, UnaryOpType::Abs, "abs"},
      OpTuple{at::acos, UnaryOpType::Acos, "acos"},
      OpTuple{at::asin, UnaryOpType::Asin, "asin"},
      OpTuple{at::atan, UnaryOpType::Atan, "atan"},
      // There does not appear to be an appropriate ATen function for atanh
      // OpTuple{at::atanh, UnaryOpType::Atanh, "atanh" },
      OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"},
      OpTuple{at::cos, UnaryOpType::Cos, "cos"},
      OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"},
      OpTuple{at::erf, UnaryOpType::Erf, "erf"},
      OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"},
      OpTuple{at::exp, UnaryOpType::Exp, "exp"},
      OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"},
      OpTuple{at::floor, UnaryOpType::Floor, "floor"},
      OpTuple{at::frac, UnaryOpType::Frac, "frac"},
      OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"},
      OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"},
      OpTuple{at::log, UnaryOpType::Log, "log"},
      OpTuple{at::log10, UnaryOpType::Log10, "log10"},
      OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"},
      OpTuple{at::log2, UnaryOpType::Log2, "log2"},
      OpTuple{at::neg, UnaryOpType::Neg, "neg"},
OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"},
      OpTuple{at::relu, UnaryOpType::Relu, "relu"},
      OpTuple{at::round, UnaryOpType::Round, "round"},
      OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"},
      OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"},
      OpTuple{at::sin, UnaryOpType::Sin, "sin"},
      OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"},
      OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"},
      OpTuple{at::tan, UnaryOpType::Tan, "tan"},
      OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"},
      OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}};

  std::for_each(ops.begin(), ops.end(), [](OpTuple& op) {
    test_op(
        /*blocks*/ 640,
        /*threads*/ 64,
        /*name*/ std::get<2>(op),
        /*Aten Func   */
        [&op](std::array& vals) {
          return std::get<0>(op)(vals[0].toTensor());
        },
        /*JIT  Func   */
        [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); },
        /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
        /*Inputs Tuple*/
        std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
  });

  // RandLike is tested separately since it compares RNG streams rather than a
  // deterministic function.
  test_op(
      /*blocks*/ 128,
      /*threads*/ 64,
      /*name*/ "rand_like",
      /*Aten Func   */
      [](std::array& vals) {
        return at::rand_like(vals[0].toTensor());
      },
      /*JIT  Func   */
      [](Val* in1) -> Val* { return unaryOp(UnaryOpType::RandLike, in1); },
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
}

// Runs comparison (bool-output) and arithmetic binary ops, plus the
// alpha-scaled add/sub variants, through test_op.
void testGPU_FusionBinaryOps() {
  using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&);
  using OpTuple = std::tuple;

  // see [Note: explicit tuple type for uniform initialization list]
  std::vector logic_ops{OpTuple{at::eq, BinaryOpType::Eq, "eq"},
                        OpTuple{at::ge, BinaryOpType::GE, "ge"},
                        OpTuple{at::gt, BinaryOpType::GT, "gt"},
                        OpTuple{at::le, BinaryOpType::LE, "le"},
                        OpTuple{at::lt, BinaryOpType::LT, "lt"},
                        OpTuple{at::ne, BinaryOpType::NE, "ne"}};

  std::for_each(logic_ops.begin(), logic_ops.end(), [](OpTuple& op) {
    test_op(
        /*blocks*/ 640,
        /*threads*/ 64,
        /*name*/ std::get<2>(op),
        /*Aten Func   */
        [&op](std::array& vals) {
          return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
        },
        /*JIT  Func   */
        [&op](Val* in1, Val* in2) -> Val* {
          return binaryOp(std::get<1>(op), in1, in2);
        },
        /*Output      */ std::make_pair(ValType::TensorView, DataType::Bool),
        /*Inputs Tuple*/
        std::make_tuple(
            std::make_pair(ValType::TensorView, DataType::Float),
            std::make_pair(ValType::TensorView, DataType::Float)));
  });

  // see [Note: explicit tuple type for uniform initialization list]
  std::vector math_ops{
      OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"},
      OpTuple{at::div, BinaryOpType::Div, "div"},
      OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"},
      OpTuple{at::max, BinaryOpType::Max, "max"},
      OpTuple{at::min, BinaryOpType::Min, "min"},
      OpTuple{at::mul, BinaryOpType::Mul, "mul"},
      OpTuple{at::pow, BinaryOpType::Pow, "pow"},
      // NOTE: Remainder does not match the Aten impl exactly
      // despite using an identical function.
      OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"},
  };

  std::for_each(math_ops.begin(), math_ops.end(), [](OpTuple& op) {
    test_op(
        /*blocks*/ 640,
        /*threads*/ 64,
        /*name*/ std::get<2>(op),
        /*Aten Func   */
        [&op](std::array& vals) {
          return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor());
        },
        /*JIT  Func   */
        [&op](Val* in1, Val* in2) -> Val* {
          return binaryOp(std::get<1>(op), in1, in2);
        },
        /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
        /*Inputs Tuple*/
        std::make_tuple(
            std::make_pair(ValType::TensorView, DataType::Float),
            std::make_pair(ValType::TensorView, DataType::Float)));
  });

  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "add_alpha",
      /*Aten Func   */
      [](std::array& vals) {
        return at::add(
            vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
      },
      /*JIT  Func   */ static_cast(&add_alpha),
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::Scalar, DataType::Float)));

  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "sub_alpha",
      /*Aten Func   */
      [](std::array& vals) {
        return at::sub(
            vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar());
      },
      /*JIT  Func   */ static_cast(&sub_alpha),
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::Scalar, DataType::Float)));
}

// Runs clamp, threshold, and where through test_op against their ATen
// references.
void testGPU_FusionTernaryOps() {
  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "clamp",
      /*Aten Func   */
      [](std::array& vals) {
        return at::clamp(vals[0].toTensor(), 0.f, 1.f);
      },
      /*JIT  Func   */
      [](Val* in1) -> Val* {
        return clamp(in1, new Float(0.f), new Float(1.f));
      },
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "threshold",
      /*Aten Func   */
      [](std::array& vals) {
        return at::threshold(vals[0].toTensor(), 0.f, 1.f);
      },
      /*JIT  Func   */
      [](Val* in1) -> Val* {
        return threshold(in1, new Float(0.f), new Float(1.f));
      },
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(std::make_pair(ValType::TensorView, DataType::Float)));
  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "where",
      /*Aten Func   */
      [](std::array& vals) {
        return at::where(
            vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor());
      },
      /*JIT  Func   */ static_cast(&where),
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(
          std::make_pair(ValType::TensorView, DataType::Bool),
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float)));
}

// Runs the fused compound ops (lerp, addcmul) through test_op.
void testGPU_FusionCompoundOps() {
  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "lerp",
      /*Aten Func   */
      [](std::array& vals) {
        return at::lerp(
            vals[0].toTensor(),
vals[1].toTensor(), vals[2].toTensor());
      },
      /*JIT  Func   */ static_cast(&lerp),
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float)));
  test_op(
      /*blocks*/ 640,
      /*threads*/ 64,
      /*name*/ "addcmul",
      /*Aten Func   */
      [](std::array& vals) {
        return at::addcmul(
            vals[0].toTensor(),
            vals[1].toTensor(),
            vals[2].toTensor(),
            vals[3].toScalar());
      },
      /*JIT  Func   */ static_cast(&addcmul),
      /*Output      */ std::make_pair(ValType::TensorView, DataType::Float),
      /*Inputs Tuple*/
      std::make_tuple(
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::TensorView, DataType::Float),
          std::make_pair(ValType::Scalar, DataType::Float)));
}

// Round-trips a half tensor through a FP16 -> FP32 -> FP16 cast chain in the
// fuser and checks it matches ATen's cast pair exactly.
void testGPU_FusionCastOps() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(2, DataType::Half);

  TensorView* intrm1 = castOp(DataType::Float, tv0);
  TensorView* out = castOp(DataType::Half, intrm1);

  fusion.addInput(tv0);
  fusion.addOutput(out);
  tv0->computeAt(out, -1);

  out->axis(0)->parallelize(ParallelType::BIDx);
  out->axis(-1)->parallelize(ParallelType::TIDx);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);

  at::Tensor input1 = at::rand({1, 4}, options);
  at::Tensor ref_output = at::empty_like(input1);

  std::array inputs = {input1};
  const at::ArrayRef input_ivalues(inputs);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion(input_ivalues);

  ref_output = at::_cast_Half(at::_cast_Float(input1));

  TORCH_CHECK(
      outputs[0].equal(ref_output),
      "\nOp Type: -- ",
      "cast FP16->FP32->FP16",
      " -- had a mismatch.\n",
      "IN1 : ",
      input1,
      "\n",
      "JIT: ",
      outputs[0],
      "\n",
      "REF: ",
      ref_output,
      "\n");
}

// We want split/merge/reorder all tested both on and off rfactor domains,
// also want compute at into the rfactor domain, and into its consumer
void testGPU_FusionRFactorReplay() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  // Register your inputs
  fusion.addInput(tv0);

  // Do math with it, it returns a `Val*` but can be static_casted back to
  // TensorView
  TensorView* tv1 = sum(tv0, {1});
  // tv1[I0, R1]
  tv1->split(0, 32);
  // tv1[I0o, I0i{32}, R1]
  tv1->split(0, 16);
  // tv1[I0oo, I0oi{16}, I0i{32}, R1]
  tv1->split(-1, 8);
  // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}]
  tv1->split(-2, 4);
  // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}]
  tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}});
  // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}]
  tv1->merge(0);
  tv1->merge(-2);
  // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}]

  TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0});
  // new_domain[r(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}]
  TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), {0});
  // new_domain2[ I0oi{16}, , I0oo*I0i{32}, R1oi{4}]

  // Move rfactor axis to end, keep iter rfactor axis
  new_domain->reorder({{0, -1}, {2, 2}});

  // Replay casp, replay new_domain2 as new_domain
  // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf,
  // R(R1oo*R1i{8})rf]
  auto replay_casp = TransformReplay::replayCasP(new_domain2, new_domain, 2);
  TensorDomain* casp = replay_casp.first;
  // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf]
  // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}]

  casp->split(1, new Int(2));
  // casp [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4} ]
  // new_domain[I0oi{16}, I0oo*I0i{32} , ir1oi{4}rf,
  //            R(R1oo*R1i{8})rf]
  auto replay_pasc = TransformReplay::replayPasC(new_domain, casp, 2);
  TensorDomain* pasc = replay_pasc.first;
  // pasc [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf,
  //       R(R1oo*R1i{8})rf]

  // NOTE(review): TORCH_CHECK only tests its FIRST argument; the second and
  // third comparisons here become part of the error message instead of being
  // checked. Likely intended as three separate TORCH_CHECKs.
  TORCH_CHECK(
      new_domain->nDims() - 1 == new_domain2->nDims(),
      casp->nDims() == new_domain2->nDims() + 1,
      pasc->nDims() == new_domain->nDims() + 1,
      "Error in rfactor, number of dimensions is not correct.");

  // Same caveat: only the first (compound) condition is actually checked.
  TORCH_CHECK(
      !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) &&
          !new_domain->sameAs(new_domain2) &&
          !tv1->domain()->sameAs(new_domain) &&
          !tv1->domain()->sameAs(new_domain2),
      "Error in rfactor, number of dimensions is not correct.");

  auto dom = new_domain->getRootDomain();
  TORCH_CHECK(
      !dom[0]->isReduction() &&
          std::any_of(
              dom.begin(),
              dom.end(),
              [](IterDomain* id) { return id->isReduction(); }) &&
          std::any_of(
              dom.begin(),
              dom.end(),
              [](IterDomain* id) { return id->isRFactorProduct(); }),
      "Error in rFactor, there seems to be something wrong in root domain.");

  auto dom2 = new_domain2->getRootDomain();
  TORCH_CHECK(
      !dom2[0]->isReduction() &&
          std::any_of(
              dom2.begin(),
              dom2.end(),
              [](IterDomain* id) { return id->isReduction(); }),
      "Error in rFactor, there seems to be something wrong in root domain.");
}

// Start off simple, block on the outer dim
// block stride + thread all reduce + unrolling on inner dim
// (this test continues past the end of this chunk)
void testGPU_FusionReduction() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  // tv1[I0, R1] = tv0[I0, I1]
  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
  fusion.addOutput(tv1);

  TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion.");

  tv1->split(1, 128);
  // tv1[I0, R1o, R1i{128}] = tv0[I0, I1]
  tv1->split(1, 4);
  // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1]

  TensorView* tv2 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
  // tv1[I0,       R1oi{4},  R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]

  TensorView* tv3 = tv1->rFactor({1});
  // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1]
  // tv3[I0,       R1oi{4},  Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}]
  // tv1[I0,                  R1i{128}] = tv3[I0,       R1oi{4},  Ir1i{128}]

  // Incrementally, can print in between for debugging
  tv0->computeAt(tv2, 1);
  tv2->computeAt(tv3, 1);
  tv3->computeAt(tv1, 1);
// Re do it all at once, because why not. tv0->computeAt(tv1, 1); tv2->axis(2)->parallelize(ParallelType::Unroll); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); int numel_x = 65000; int numel_y = 1025; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } void testGPU_FusionReduction2() { { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); // switches to try some different scenarios. maybe we should iterate on all // permutations. 
bool bind_bidx = true; bool bind_tidx = true; bool bind_tidy = true; bool bind_unroll = true; int numel_x = 1025; // Cannot exceed block dim max size / tidy int numel_y = 129; int tidx = 16; int tidy = 8; int unroll_factor = 4; tv1->split(1, tidx); // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1] tv1->split(1, unroll_factor); // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1] tv1->split(0, tidy); TensorView* tv2 = tv1->rFactor({-3}); // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}] // tv1[I0o, I0i{tidy}, R1oi{unroll}, R1i{tidx}] TensorView* tv3 = tv1->rFactor({-2}); // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}] // tv3[I0, R1oi{unroll}, Ir1i{tidx}] // tv1[I0o, I0i{tidy}, R1i{tidx}] tv0->computeAt(tv1, -2); if (bind_unroll) tv2->axis(-2)->parallelize(ParallelType::Unroll); if (bind_bidx) tv1->axis(0)->parallelize(ParallelType::BIDx); if (bind_tidy) tv1->axis(1)->parallelize(ParallelType::TIDy); if (bind_tidx) { tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(outputs[0])); } { // What if Z participates in the reduction with X? 
Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); int numel_x = 1025; // Cannot exceed block dim max size / tidy int numel_y = 129; int tidx = 16; int tidz = 8; tv1->split(1, tidz); // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1] tv1->split(1, tidx); // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({-3}); // tv2[I0, >R1oo<, Ir1oi{tidx}, Ir1i{tidz}] // tv1[I0o, R1oi{tidx}, R1i{tidz}] tv0->computeAt(tv1, -3); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(-2)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDz); tv2->axis(-2)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDz); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } } void testGPU_FusionReduction3() { { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); TensorView* tv2 = add(tv0, tv1); // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1] fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv3 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv2); // tv3[I0, R1] = tv2[I0, I1] TensorView* tv4 = makeDummyTensor(1); fusion.addInput(tv4); // tv5[I0] = tv3[I0, R1] * tv4[I0] TensorView* tv5 = mul(tv3, tv4); fusion.addOutput(tv5); int tidx = 16; // RFactor the reduction tv3->split(1, 
tidx); // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1] TensorView* tv6 = tv3->rFactor({-2}); // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1] // tv3[I0, R1i{tidx}] = tv3[I0, I1] tv2->computeAt(tv6, 2); // Compute at inline with tv5 (only 1D) tv6->computeAt(tv3, 1); tv3->computeAt(tv5, 1); tv5->axis(0)->parallelize(ParallelType::BIDx); // Intermediate tensors only need this, but doesn't hurt to do on inputs // tv0, 1, 4 tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv6->axis(-1)->parallelize(ParallelType::TIDx); int numel_x = 1025; int numel_y = 129; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::rand({numel_x, numel_y}, options); at::Tensor t1 = at::rand({numel_x, numel_y}, options); auto t2 = t0.add(t1); auto t3 = t2.sum({1}); at::Tensor t4 = at::rand({numel_x}, options); auto t5 = t3.mul(t4); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1, t4}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); TORCH_CHECK( t5.allclose(outputs[0]), "Error of: ", t5.sub(outputs[0]).abs().max()); } } void testGPU_FusionReduction4() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); int bidy = 2; int tidy = 4; int tidx = 5; int dim1 = 11; tv1->split(-2, tidy); TensorView* tv2 = tv1->rFactor({-3}); tv0->computeAt(tv1, 1); tv1->axis(0)->parallelize(ParallelType::BIDy); for (auto* val : fusion.vals()) { if (!fusion.hasInput(val) && val->getValType().value() == ValType::TensorView) { val->as()->axis(-1)->parallelize(ParallelType::TIDx); } } tv2->axis(-2)->parallelize(ParallelType::TIDy); tv1->axis(-2)->parallelize(ParallelType::TIDy); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 
0); at::Tensor input = at::randn({bidy, dim1, tidx}, options); at::Tensor cg_output = at::empty({bidy, tidx}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK( aten_output.allclose(cg_output, 1e-5, 1e-7), "Error of: ", aten_output.sub(cg_output).abs().max()); } void testGPU_FusionReduction5() { Fusion fusion; FusionGuard fg(&fusion); const int bdimx = 64; const int bdimy = 8; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(2, bdimx); // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] tv1->split(1, bdimy); // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2] TensorView* tv2 = tv1->rFactor({3}); // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2] // tv1[I0, R1o, R1i{8}, R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}] TensorView* tv3 = tv1->rFactor({1}); // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2] // tv3[I0, R1o, I1i{8}, I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}] // tv1[I0, R1i{8}, R2i{128}] = tv3[I0, R1o, I1i{8}, I2i{128}] tv3->computeAt(tv1, 1); tv2->computeAt(tv3, 2); tv1->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::TIDy); tv3->axis(-2)->parallelize(ParallelType::TIDy); tv2->axis(-3)->parallelize(ParallelType::TIDy); int numel_x = 650; int numel_y = 1000; int numel_z = 4; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y, 
numel_z}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum({1, 2}); TORCH_CHECK(aten_output.allclose(outputs[0])); } void testGPU_FusionReductionTFT() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); int numel_x = 1025; int numel_y = 129; int tidx = 16; int tidy = 8; int tidz = 8; tv1->split(1, tidx); // tv1[I0, R1o, R1i{tidx}] tv1->split(1, tidz); // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}] tv1->split(0, tidy); // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}] TensorView* tv2 = tv1->rFactor({2}); // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}] // tv1[I0o, I0i, R1Oi{tidz}, R1R1i{tidx}] tv2->computeAt(tv1, 2); tv1->axis(1)->parallelize(ParallelType::TIDy); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::TIDz); tv2->axis(-2)->parallelize(ParallelType::TIDz); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } void testGPU_FusionBranches() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = makeDummyTensor(2); TensorView* tv2 = makeDummyTensor(2); fusion.addInput(tv0); fusion.addInput(tv1); fusion.addInput(tv2); auto tv3 = add(tv0, new Float(1.0)); auto tv4 = add(tv3, 
tv1); auto tv5 = add(tv3, tv2); auto tv6 = add(tv4, tv5); fusion.addOutput(tv6); constexpr int x = 63, y = 33; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y}, options); at::Tensor t1 = at::randn({x, y}, options); at::Tensor t2 = at::randn({x, y}, options); torch::jit::fuser::cuda::FusionExecutor fe; tv6->merge(0); tv6->split(0, 128); tv6->split(0, 4); tv6->axis(0)->parallelize(ParallelType::BIDx); tv0->computeAt(tv6, 1); tv1->computeAt(tv6, 1); tv2->computeAt(tv6, 1); tv3->axis(-2)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv4->axis(-2)->parallelize(ParallelType::Unroll); tv4->axis(-1)->parallelize(ParallelType::TIDx); tv5->axis(-2)->parallelize(ParallelType::Unroll); tv5->axis(-1)->parallelize(ParallelType::TIDx); tv6->axis(-1)->parallelize(ParallelType::TIDx); fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1, t2}); auto t3 = t0.add(1.0); auto t4 = t3.add(t1); auto t5 = t3.add(t2); auto t6 = t4.add(t5); TORCH_CHECK(t6.allclose(outputs[0])); } void testGPU_FusionSimpleBCast() { { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv1 = add(tv0, new Float(1.5)); TensorView* tv2 = makeDummyTensor(2); fusion.addInput(tv2); TensorView* tv3 = makeDummyTensor(2); fusion.addInput(tv3); TensorView* tv4 = sub(tv2, tv3); TensorView* tv5 = broadcast(tv1, {false, false, true}); TensorView* tv6 = broadcast(tv4, {true, false, false}); TensorView* tv7 = add(tv5, tv6); fusion.addOutput(tv7); tv7->split(-1, 4); tv7->split(0, 8); tv0->computeAt(tv7, -1); tv2->computeAt(tv7, -1); tv7->axis(0)->parallelize(ParallelType::BIDx); tv7->axis(-1)->parallelize(ParallelType::TIDx); constexpr int x = 63, y = 33, z = 15; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y}, options); at::Tensor t1 = t0.add(1.5); at::Tensor t2 = 
at::randn({y, z}, options); at::Tensor t3 = at::randn({y, z}, options); at::Tensor t4 = t2.sub(t3); at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z}); at::Tensor t6 = t4.expand({x, y, z}); at::Tensor t7 = t5.add(t6); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t2, t3}); TORCH_CHECK(t7.allclose(outputs[0])); } { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv1 = makeDummyTensor(2); fusion.addInput(tv1); TensorView* tv2 = add(tv0, tv1); TensorView* tv3 = broadcast(tv2, {false, false, true}); TensorView* tv4 = makeDummyTensor(2); fusion.addInput(tv4); TensorView* tv5 = sub(tv4, new Float(0.1)); TensorView* tv6 = broadcast(tv5, {true, false, false}); TensorView* tv7 = add(tv3, tv6); fusion.addOutput(tv7); tv7->merge(0, 1); tv0->computeAt(tv7, -1); tv4->computeAt(tv7, -1); tv7->axis(0)->parallelize(ParallelType::BIDx); tv7->axis(-1)->parallelize(ParallelType::TIDx); constexpr int x = 63, y = 33, z = 15; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y}, options); at::Tensor t1 = at::randn({x, y}, options); at::Tensor t2 = t0.add(t1); at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z}); at::Tensor t4 = at::randn({y, z}, options); at::Tensor t5 = t4.sub(0.1); at::Tensor t6 = t5.expand({x, y, z}); at::Tensor t7 = t3.add(t6); at::Tensor cg_output = at::empty({x, y, z}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1, t4}, {cg_output}); TORCH_CHECK(t7.allclose(cg_output)); } { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views std::vector dom; dom.push_back(new IterDomain(new Int(0), new Int())); dom.push_back(new IterDomain( new Int(0), new Int(1), ParallelType::Serial, IterType::BroadcastWithStride)); // tv0[I1, B{1}] TensorView* tv0 = new TensorView(new TensorDomain(dom), 
DataType::Float); fusion.addInput(tv0); // tv1[I0, I1, I2] TensorView* tv2 = makeDummyTensor(3); fusion.addInput(tv2); TensorView* tv3 = add(tv0, tv2); fusion.addOutput(tv3); tv3->merge(0); tv3->merge(0); tv0->computeAt(tv3, -1); tv2->computeAt(tv3, -1); tv3->axis(0)->parallelize(ParallelType::BIDx); constexpr int x = 2, y = 3, z = 4; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({y, 1}, options); at::Tensor t2 = at::randn({x, y, z}, options); auto t3 = t0.add(t2); at::Tensor cg_output = at::empty({x, y, z}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t2}, {cg_output}); TORCH_CHECK(t3.allclose(cg_output)); } { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views std::vector dom; dom.push_back(new IterDomain( new Int(0), new Int(1), ParallelType::Serial, IterType::BroadcastWithStride)); dom.push_back(new IterDomain(new Int(0), new Int())); TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float); TensorView* tv1 = makeDummyTensor(3); fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv3 = add(tv0, tv1); tv3->merge(0); tv3->merge(0); tv3->split(0, 128); tv3->split(0, 4); fusion.addOutput(tv3); tv0->computeAt(tv3, -1); tv1->computeAt(tv3, -1); tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-2)->parallelize(ParallelType::Unroll); constexpr int x = 63, y = 33, z = 15; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({1, z}, options); at::Tensor t1 = at::randn({x, y, z}, options); at::Tensor cg_output = at::empty({x, y, z}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1}, {cg_output}); auto t3 = t0.add(t1); TORCH_CHECK(t3.allclose(cg_output)); } { Fusion fusion; FusionGuard fg(&fusion); constexpr int m = 2, k = 3, n = 4; auto zero = new Int(0); auto M = new 
IterDomain(zero, new Int(m)); auto K = new IterDomain(zero, new Int(k)); auto N = new IterDomain(zero, new Int(n)); // Set up your input tensor views TensorView* tv0 = new TensorView(new TensorDomain({M, K}, {true, true}), DataType::Float); TensorView* tv1 = new TensorView(new TensorDomain({K, N}, {true, true}), DataType::Float); fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv2 = broadcast(tv0, {false, false, true}); TensorView* tv3 = broadcast(tv1, {true, false, false}); TensorView* tv4 = add(tv2, tv3); fusion.addOutput(tv4); tv4->merge(0); tv4->merge(0); tv0->computeAt(tv4, -1); tv1->computeAt(tv4, -1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({m, k}, options); at::Tensor t1 = at::randn({k, n}, options); at::Tensor cg_output = at::empty({m, k, n}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1}, {cg_output}); auto t2 = t0.unsqueeze(-1).expand({m, k, n}); auto t3 = t1.expand({m, k, n}); auto t4 = t2.add(t3); TORCH_CHECK(t4.allclose(cg_output)); } } void testGPU_FusionComplexBCast() { { Fusion fusion; FusionGuard fg(&fusion); int x = 2, y = 3, z = 4; auto tv0 = makeConcreteTensor({y}); auto tv1 = div(tv0, new Float(2.0)); auto tv2 = broadcast(tv1, {false, true}); auto tv3 = makeConcreteTensor({y, z}); auto tv4 = mul(tv2, tv3); auto tv5 = broadcast(tv4, {true, false, false}); auto tv6 = makeConcreteTensor({x, y, z}); auto tv7 = add(tv5, tv6); // tv0[ i1 ] = input // tv1[ i1 ] = tv0/2.0 // tv2[ i1, b2] = bcast(tv1) // tv3[ i1, i2] = input // tv4[ i1, i2] = tv2 * tv3 // tv5[b0, i1, i2] = bcast(tv4) // tv6[i0, i1, i2] = input // tv7[i0, i1, i2] = tv5 + tv6 // tv4 = bcast(tv1) * tv3 // tv7 = bcast(tv4) + tv6 fusion.addInput(tv0); fusion.addInput(tv3); fusion.addInput(tv6); fusion.addOutput(tv7); tv7->merge(0); tv7->merge(0); tv0->computeAt(tv7, -1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = 
at::randn({y}, options); at::Tensor t3 = at::randn({y, z}, options); at::Tensor t6 = at::randn({x, y, z}, options); auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3; auto t7 = t4.unsqueeze(0).expand({x, y, z}) + t6; torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t3, t6}); TORCH_CHECK(t7.allclose(outputs[0])); } { Fusion fusion; FusionGuard fg(&fusion); int x = 2, y = 3, z = 4; auto tv0 = makeConcreteTensor({y, z}); auto tv1 = div(tv0, new Float(2.0)); auto tv2 = sum(tv1, {1}); auto tv3 = broadcast(tv2, {true, false}); auto tv4 = makeConcreteTensor({x, y}); auto tv5 = add(tv3, tv4); // tv0[ i1, i2] = input // tv1[ i1, i2] = tv0/2.0 // tv2[ i1 ] = sum(tv1, 1) // tv3[b0, i1 ] = bcast(tv2) // tv4[i0, i1 ] = input // tv5[i0, i1 ] = tv3 + tv4 // tv2 = sum(tv0/2.0, 1) // tv5 = bcast(tv2) + tv4 fusion.addInput(tv0); fusion.addInput(tv4); fusion.addOutput(tv5); tv5->merge(0); tv0->computeAt(tv5, -1); tv1->computeAt(tv2, -1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({y, z}, options); auto t1 = t0.div(2.0); auto t2 = t1.sum(1); auto t3 = t2.unsqueeze(0).expand({x, y}); at::Tensor t4 = at::randn({x, y}, options); auto t5 = t3.add(t4); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t4}); TORCH_CHECK(t5.allclose(outputs[0])); } } void testGPU_FusionAdvancedIndexing() { // Merging left to right is still broken in some instances. Indexing can't // complete because we assume we can simply traverse consumer->producer in the // index/extent map, but this case breaks this assumption. 
{ Fusion fusion; FusionGuard fg(&fusion); int w = 3, x = 4, y = 7, z = 8; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto tv0 = makeDummyTensor(3); auto tv1 = makeDummyTensor(4); fusion.addInput(tv0); fusion.addInput(tv1); auto tv2 = add(tv0, new Float(1.0)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); tv4->merge(0); tv4->merge(0); tv4->merge(0); tv4->split(0, 128); tv4->split(0, 4); tv2->computeAt(tv4, 1); tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::Unroll); tv4->axis(2)->parallelize(ParallelType::TIDx); tv3->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(2)->parallelize(ParallelType::TIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv2->axis(2)->parallelize(ParallelType::TIDx); torch::jit::fuser::cuda::FusionExecutor fe; at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); auto t3 = t0.add(1.0); auto t4 = t3.add(t1); TORCH_CHECK(t4.allclose(outputs[0])); } // Merging right to left actually does work. 
{ Fusion fusion; FusionGuard fg(&fusion); int w = 3, x = 4, y = 7, z = 8; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto tv0 = makeDummyTensor(3); auto tv1 = makeDummyTensor(4); fusion.addInput(tv0); fusion.addInput(tv1); auto tv2 = add(tv0, new Float(1.0)); auto tv3 = broadcast(tv2, {true, false, false, false}); auto tv4 = add(tv3, tv1); fusion.addOutput(tv4); tv4->merge(-2); tv4->merge(-2); tv4->merge(-2); tv4->split(0, 128); tv4->split(0, 4); tv2->computeAt(tv4, 1); tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->axis(1)->parallelize(ParallelType::Unroll); tv4->axis(2)->parallelize(ParallelType::TIDx); tv3->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(2)->parallelize(ParallelType::TIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv2->axis(2)->parallelize(ParallelType::TIDx); torch::jit::fuser::cuda::FusionExecutor fe; at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); auto t3 = t0.add(1.0); auto t4 = t3.add(t1); TORCH_CHECK(t4.allclose(outputs[0])); } // Same issue as the first one in this section { Fusion fusion; FusionGuard fg(&fusion); int w = 3, x = 4, y = 7, z = 8; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y, z}, options); at::Tensor t1 = at::randn({w, x, y, z}, options); auto tv0 = makeDummyTensor(3); auto tv1 = makeDummyTensor(4); fusion.addInput(tv0); fusion.addInput(tv1); auto tv2 = add(tv0, new Float(1.0)); auto tv3 = add(tv2, tv1); fusion.addOutput(tv3); fuser::cuda::scheduleFusion(&fusion, {t0, t1}); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t1}); auto t2 = t0.add(1.0); auto t3 = t2.add(t1); TORCH_CHECK(t3.allclose(outputs[0])); } } // Test a simple Gemm but also play around with fusion executor features void testGPU_FusionSimpleGemm() { Fusion fusion; FusionGuard 
fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); // M, K TensorView* tv1 = makeDummyTensor(2); // K, N fusion.addInput(tv0); fusion.addInput(tv1); TensorView* tv2 = broadcast(tv0, {false, false, true}); // tv2[I0, I1, B] = tv0[I0, I1] TensorView* tv3 = broadcast(tv1, {true, false, false}); // tv3[B, I1, I2] = tv1[I1, I2] // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2] TensorView* tv4 = mul(tv2, tv3); // tv5[I0, R1, I2] = tv4[I0, I1, I2] TensorView* tv5 = sum(tv4, {1}); fusion.addOutput(tv5); tv5->split(1, 32); // tv5[I0, R1o, R1i{32}, I2] auto tv6 = tv5->rFactor({1}); // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2] // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2] tv5->split(0, 4); tv5->split(-1, 4); // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] tv0->computeAt(tv5, -1); tv1->computeAt(tv5, -1); // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] //--> (line symbolizes compute at location) // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] tv0->computeAt(tv6, -1); tv1->computeAt(tv6, -1); // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] tv5->axis(0)->parallelize(ParallelType::BIDz); tv5->axis(1)->parallelize(ParallelType::TIDz); tv5->axis(-2)->parallelize(ParallelType::BIDy); tv5->axis(-1)->parallelize(ParallelType::TIDy); tv5->axis(2)->parallelize(ParallelType::TIDx); tv6->axis(2)->parallelize(ParallelType::TIDx); constexpr int M = 65, K = 33, N = 17; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({M, K}, options); at::Tensor t1 = at::randn({K, N}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); // Lets specify a few bounds in launch params to make sure it works fe.runFusion( {t0, 
t1}, torch::jit::fuser::cuda::LaunchParams(1, -1, -1, 32, 4, 4)); // Make sure bad launch params throws ASSERT_ANY_THROW(fe.runFusion( {t0, t1}, torch::jit::fuser::cuda::LaunchParams(1, 2, 3, 4, 5, 6))); // Don't specify any launch params auto outputs = fe.runFusion({t0, t1}); auto t2 = t0.matmul(t1); TORCH_CHECK( t2.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", t2.sub(outputs[0]).abs().max()); } // Softmax with a 1D tensor. Parallelized only with a single thread block. void testGPU_FusionSoftmax1D() { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 128; const int dimx = 1000; // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(1); fusion.addInput(input_tv0); TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true}); // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be // computed at sum_exp_rf_tv8. TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); fusion.addOutput(output_tv4); bcast_sum_tv3->split(0, tidx); sum_exp_tv2->split(-1, tidx); TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); output_tv4->split(-1, tidx); exp_tv1->computeAt(sum_exp_rf_tv5, -1); exp_tv1_copy->computeAt(output_tv4, -1); TensorView* tensors_to_parallelize[] = { sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; for (auto tv : tensors_to_parallelize) { tv->axis(-1)->parallelize(ParallelType::TIDx); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx}, options); at::Tensor cg_output = at::empty({dimx}, options); at::Tensor t3_output = at::empty_like(cg_output, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {cg_output}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( t2.allclose(cg_output, 1e-5, 1e-5), "Error of: ", 
t2.sub(cg_output).abs().max()); } // Softmax with a 1D tensor with input normalization. void testGPU_FusionSoftmax1DNormalized() { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 128; const int dimx = 1000; // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(1); fusion.addInput(input_tv0); // Normalize with the max value before computing exp. TensorView* max_val_tv1 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true}); TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true}); // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be // computed at sum_exp_rf_tv8. TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); fusion.addOutput(output_tv7); bcast_max_tv2->split(0, tidx); bcast_sum_tv6->split(0, tidx); max_val_tv1->split(-1, tidx); TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); sum_exp_tv5->split(-1, tidx); TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); output_tv7->split(-1, tidx); sub_tv3->computeAt(sum_exp_rf_tv9, -1); sub_tv3_copy->computeAt(output_tv7, -1); TensorView* tensors_to_parallelize[] = {max_val_tv1, bcast_max_tv2, sum_exp_tv5, bcast_sum_tv6, output_tv7, max_val_rf_tv8, sum_exp_rf_tv9}; for (auto tv : tensors_to_parallelize) { tv->axis(-1)->parallelize(ParallelType::TIDx); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx}, options); at::Tensor t3_output = at::empty({dimx}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( t2.allclose(outputs[0], 1e-5, 1e-5), 
"Error of: ", t2.sub(outputs[0]).abs().max()); } // Softmax with a 3D tensor, where the inner-most 3rd dimension is // normalized. Pallelized with multiple thread blocks. void testGPU_FusionSoftmax3D() { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 32; const int dimx = 32; const int dimy = 16; const int dimz = 130; // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); fusion.addInput(input_tv0); TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true}); // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be // computed at sum_exp_rf_tv8. TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); fusion.addOutput(output_tv4); bcast_sum_tv3->split(-1, tidx); sum_exp_tv2->split(-1, tidx); TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); output_tv4->split(-1, tidx); exp_tv1->computeAt(sum_exp_rf_tv5, -1); exp_tv1_copy->computeAt(output_tv4, -1); TensorView* tensors_to_parallelize[] = { sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; for (auto tv : tensors_to_parallelize) { tv->axis(0)->parallelize(ParallelType::BIDx); tv->axis(1)->parallelize(ParallelType::BIDy); tv->axis(-1)->parallelize(ParallelType::TIDx); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); at::Tensor t3_output = at::empty_like(cg_output, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0}, {cg_output}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( t2.allclose(cg_output, 1e-5, 1e-5), "Error of: ", t2.sub(cg_output).abs().max()); } // Softmax with a 3D tensor with input normalization. 
void testGPU_FusionSoftmax3DNormalized() { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 32; const int dimx = 32; const int dimy = 16; const int dimz = 130; // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); fusion.addInput(input_tv0); // Normalize with the max value before computing exp. TensorView* max_val_tv1 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), input_tv0); TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true}); TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true}); // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be // computed at sum_exp_rf_tv8. TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); fusion.addOutput(output_tv7); bcast_max_tv2->split(-1, tidx); bcast_sum_tv6->split(-1, tidx); max_val_tv1->split(-1, tidx); TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); sum_exp_tv5->split(-1, tidx); TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); output_tv7->split(-1, tidx); sub_tv3->computeAt(sum_exp_rf_tv9, -1); sub_tv3_copy->computeAt(output_tv7, -1); TensorView* tensors_to_parallelize[] = {max_val_tv1, bcast_max_tv2, sum_exp_tv5, bcast_sum_tv6, output_tv7, max_val_rf_tv8, sum_exp_rf_tv9}; for (auto tv : tensors_to_parallelize) { tv->axis(0)->parallelize(ParallelType::BIDx); tv->axis(1)->parallelize(ParallelType::BIDy); tv->axis(-1)->parallelize(ParallelType::TIDx); } auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({dimx, dimy, dimz}, options); at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = 
fe.runFusion({t0}); auto t2 = at::_softmax(t0, -1, false); TORCH_CHECK( t2.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", t2.sub(outputs[0]).abs().max()); } void testGPU_FusionSoftmaxComputeAt() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); auto tv2 = broadcast(tv1, {false, true}); auto tv3 = add(tv0, new Float(1.0)); auto tv4 = mul(tv2, tv3); auto tv5 = sum(tv4, {1}); auto tv6 = broadcast(tv5, {false, true}); auto tv7 = sub(tv6, tv4); fusion.addOutput(tv7); tv1->computeAt(tv7, 1); ASSERT_ANY_THROW(tv1->computeAt(tv7, -1)); } // Similar to FusionReduction but uses grid reduction void testGPU_FusionGridReduction1() { const int gdimx = 32; const int bdimx = 128; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] tv1->split(1, gdimx); // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] // Incrementally, can print in between for debugging tv0->computeAt(tv2, 1); tv2->computeAt(tv1, 1); // Re do it all at once, because why not. 
tv0->computeAt(tv1, 1); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(1)->parallelize(ParallelType::BIDx); tv2->axis(2)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); int numel_x = 10000; int numel_y = 65000; // fusion.printKernel(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } // Same test as the above but uses BIDy and TIDx for reduction void testGPU_FusionGridReduction2() { const int gdimy = 32; const int bdimx = 128; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] tv1->split(1, gdimy); // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] // Incrementally, can print in between for debugging tv0->computeAt(tv2, 1); tv2->computeAt(tv1, 1); // Re do it all at once, because why not. 
tv0->computeAt(tv1, 1); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::BIDy); tv2->axis(2)->parallelize(ParallelType::BIDy); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); int numel_x = 10000; int numel_y = 65000; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(outputs[0])); } // Same test but uses BIDy and BIDz for reduction. No TID used. void testGPU_FusionGridReduction3dim1() { const int gdimz = 32; const int gdimy = 128; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, gdimy); // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] tv1->split(1, gdimz); // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] // Incrementally, can print in between for debugging tv0->computeAt(tv2, 1); tv2->computeAt(tv1, 1); // Re do it all at once, because why not. 
tv0->computeAt(tv1, 1); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(1)->parallelize(ParallelType::BIDz); tv2->axis(2)->parallelize(ParallelType::BIDz); tv1->axis(-1)->parallelize(ParallelType::BIDy); tv2->axis(-1)->parallelize(ParallelType::BIDy); int numel_x = 100; int numel_y = 6500; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } // Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0 void testGPU_FusionGridReduction3dim0() { const int rdim = 0; const int gdimy = 128; const int gdimz = 32; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[R0, I1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {rdim}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(rdim, gdimy); // tv1[R0o, R0i{128}, I1] = tv0[I0, I1] tv1->split(rdim, gdimz); // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({rdim}); // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1] // tv1[ R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1] // Note that computeAt isn't going to make anything better as there // is no dynamically sized dimension. 
// Map parallelism as [Serial, BIDz, BIDy, BIDx] tv1->axis(-1)->parallelize(ParallelType::BIDx); tv2->axis(-1)->parallelize(ParallelType::BIDx); tv1->axis(-2)->parallelize(ParallelType::BIDy); tv2->axis(-2)->parallelize(ParallelType::BIDy); tv1->axis(-3)->parallelize(ParallelType::BIDz); tv2->axis(-3)->parallelize(ParallelType::BIDz); int numel_x = 6500; int numel_y = 100; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum({0}); TORCH_CHECK(aten_output.allclose(outputs[0])); } // This is similar to the FusionReduction, but swaps BIDx and TIDx void testGPU_FusionGridReduction4() { Fusion fusion; FusionGuard fg(&fusion); const int bdimx = 128; const int gdimx = 1024; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, gdimx); // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1] tv1->split(1, 4); // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1] // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] TensorView* tv3 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1] // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}] // Incrementally, can print in between for debugging tv0->computeAt(tv2, 1); tv2->computeAt(tv3, 1); tv3->computeAt(tv1, 1); // Re do it all at once, because why not. 
tv0->computeAt(tv1, 1); tv2->axis(2)->parallelize(ParallelType::Unroll); tv1->axis(0)->parallelize(ParallelType::TIDx); tv1->axis(-1)->parallelize(ParallelType::BIDx); tv2->axis(-1)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::BIDx); int numel_x = bdimx; int numel_y = 65000; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(cg_output)); } // Grid reduction with 2D thread blocks but only TIDx and BIDx are // mapped to a reduction dim void testGPU_FusionGridReduction5() { Fusion fusion; FusionGuard fg(&fusion); const int bdimx = 64; const int bdimy = 16; const int gdimx = 4; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); tv1->split(1, bdimx); // tv1[I0, R1o, R1i{64}] = tv0[I0, I1] tv1->split(1, gdimx); // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1] // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] tv0->computeAt(tv1, 1); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::BIDx); tv2->axis(-2)->parallelize(ParallelType::BIDx); tv1->axis(0)->parallelize(ParallelType::TIDy); int numel_x = bdimy; int numel_y = 6500; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); torch::jit::fuser::cuda::FusionExecutor fe; 
fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(outputs[0])); } // Similar to FusionGridReduction1 but with 3D tensors void testGPU_FusionGridReduction6() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(3); fusion.addInput(tv0); // tv1[I0, R1, R2] = tv0[I0, I1, I2] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1, 2}, new Float(0), tv0); fusion.addOutput(tv1); TORCH_CHECK(fusion.hasReduction(), "Could not detect reduction in fusion."); // Splitting for TID tv1->split(2, 128); // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] // Splitting for BID tv1->split(1, 128); // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2] TensorView* tv2 = tv1->rFactor({3}); // tv2[I0, I1o, I1i{128}, R2o, I2i{128}] // tv1[I0, R1o, R1i{128}, R2i{128}] TensorView* tv3 = tv1->rFactor({1}); // tv2[I0, I1o, I1i{128}, R2o, I2i{128}] // tv3[I0, R1o, I1i{128}, I2i{128}] // tv1[I0, R1i{128}, R2i{128}] tv3->computeAt(tv1, 1); tv2->computeAt(tv3, 3); tv1->axis(0)->parallelize(ParallelType::BIDy); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::BIDx); tv2->axis(-3)->parallelize(ParallelType::BIDx); tv3->axis(-2)->parallelize(ParallelType::BIDx); int numel_x = 6500; int numel_y = 200; int numel_z = numel_y; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y, numel_z}, options); at::Tensor cg_output = at::empty({numel_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {cg_output}); auto aten_output = input.sum({1, 2}); TORCH_CHECK(aten_output.allclose(cg_output)); } void testGPU_FusionNonRedAxisBind() { int bid_x = 3; int tid_x = 2; int red_dim = 0; Fusion fusion; FusionGuard fg(&fusion); // 
Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0); fusion.addOutput(tv1); tv1->split(-1, tid_x); tv1->axis(-2)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({16, bid_x * tid_x}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum({red_dim}); TORCH_CHECK( aten_output.allclose(outputs[0]), "Error of: ", aten_output.sub(outputs[0]).abs().max()); } void testGPU_FusionSplitBCast() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* input_tv0 = makeDummyTensor(3); TensorView* input_tv1 = makeDummyTensor(3); fusion.addInput(input_tv0); fusion.addInput(input_tv1); TensorView* sum_tv2 = reductionOp(BinaryOpType::Add, {2}, new Float(0), input_tv0); TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true}); TensorView* output_tv4 = div(input_tv1, bcast_tv3); sum_tv2->split(-1, 32); TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2}); bcast_tv3->split(-1, 32); output_tv4->split(-1, 32); sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx); sum_tv2->axis(0)->parallelize(ParallelType::BIDx); bcast_tv3->axis(0)->parallelize(ParallelType::BIDx); output_tv4->axis(0)->parallelize(ParallelType::BIDx); sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy); sum_tv2->axis(1)->parallelize(ParallelType::BIDy); bcast_tv3->axis(1)->parallelize(ParallelType::BIDy); output_tv4->axis(1)->parallelize(ParallelType::BIDy); sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx); sum_tv2->axis(-1)->parallelize(ParallelType::TIDx); bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx); output_tv4->axis(-1)->parallelize(ParallelType::TIDx); fusion.addOutput(output_tv4); auto options = 
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({32, 32, 128}, options); at::Tensor t1 = at::randn({32, 32, 128}, options); at::Tensor cg_output = at::empty({32, 32, 128}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({t0, t1}, {cg_output}); } void testGPU_FusionBCastInnerDim() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // reduce then broadcast auto tv1 = sum(tv0, {0}); auto tv2 = broadcast(tv1, {false, true}); TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast()); } void testGPU_FusionBCastReduce() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); auto tv1 = broadcast(tv0, {true, false, false}); auto tv2 = sum(tv1, {1}); TORCH_CHECK( tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() && !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction()); } // Multiple consumer reduction with computeAt // https://github.com/csarofeen/pytorch/issues/110 void testGPU_FusionReductionMultiConsumer() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); auto tv1 = unaryOp(UnaryOpType::Exp, tv0); auto tv2 = reductionOp(BinaryOpType::Max, {-1}, new Float(0), tv1); auto tv3 = reductionOp(BinaryOpType::Min, {-1}, new Float(0), tv1); auto tv4 = add(tv2, tv3); fusion.addOutput(tv4); tv1->computeAt(tv2, -1); TORCH_CHECK( (tv1->getComputeAtView() == tv2 || tv1->getComputeAtView() == tv3) && tv1->getThisComputeAtAxis() == 2 && tv1->getRelativeComputeAtAxis() == 2); } void testGPU_FusionComputeAtExprOrder() { { for (int i = 0; i < 2; ++i) { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(1); fusion.addInput(tv0); auto tv1 = add(tv0, new Float(1)); auto tv2 = add(tv0, new Float(1)); TensorView* tv3 = add(tv1, tv2); if (i == 0) { tv1->computeAt(tv3, 
-1); fusion.addOutput(tv2); } else { tv2->computeAt(tv3, -1); fusion.addOutput(tv1); } fusion.addOutput(tv3); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = (input + 1) * 2; TORCH_CHECK( aten_output.allclose(outputs[1]), "Error of: ", aten_output.sub(outputs[1]).abs().max()); } } { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); auto tv1 = add(tv0, new Float(1)); auto tv2 = add(tv0, new Float(1)); TensorView* tv3 = add(tv1, tv2); fusion.addOutput(tv3); tv3->split(-1, 32); tv1->computeAt(tv3, -1); tv2->computeAt(tv3, -2); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100, 100}, options); at::Tensor output = at::empty_like(input, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {output}); auto aten_output = (input + 1) * 2; TORCH_CHECK( aten_output.allclose(output), "Error of: ", aten_output.sub(output).abs().max()); } } void testGPU_FusionZeroDimComputeAt() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(1); fusion.addInput(tv0); auto tv1 = sum(tv0, {0}); auto tv2 = add(tv1, new Float(1)); fusion.addOutput(tv2); TORCH_CHECK(tv2->nDims() == 0); tv1->computeAt(tv2, 0); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({100}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum() + 1; TORCH_CHECK( aten_output.allclose(outputs[0]), "Error of: ", aten_output.sub(outputs[0]).abs().max()); } void testGPU_FusionZeroDimBroadcast() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(0); 
fusion.addInput(tv0); auto tv1 = broadcast(tv0, {true, true}); TORCH_CHECK(tv1->nDims() == 2); TensorView* tv2 = makeDummyTensor(2); fusion.addInput(tv2); auto tv3 = add(tv1, tv2); auto tv4 = sum(tv3, {0, 1}); fusion.addOutput(tv4); tv3->computeAt(tv4, -1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::rand({}, options); at::Tensor input2 = at::rand({10, 10}, options); at::Tensor output = at::empty({}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input1, input2}, {output}); auto aten_output = (input1.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + input2).sum(); TORCH_CHECK( aten_output.allclose(output), "Error of: ", aten_output.sub(output).abs().max()); } void testGPU_FusionZeroDimReduction() { Fusion fusion; FusionGuard fg(&fusion); const int bdimx = 32; const int gdimx = 32; TensorView* tv0 = makeDummyTensor(1); fusion.addInput(tv0); auto tv1 = sum(tv0, {0}); fusion.addOutput(tv1); tv1->split(0, bdimx); tv1->split(0, gdimx); auto tv2 = tv1->rFactor({0}); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(-2)->parallelize(ParallelType::BIDx); tv2->axis(-2)->parallelize(ParallelType::BIDx); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({1000}, options); at::Tensor output = at::empty({}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); fe.runFusion({input}, {output}); auto aten_output = input.sum(); TORCH_CHECK( aten_output.allclose(output), "Error of: ", aten_output.sub(output).abs().max()); } void testGPU_FusionBCastAfterReduce() { Fusion fusion; FusionGuard fg(&fusion); const int tidx = 128; // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); auto tv1 = sum(tv0, {1}); auto tv2 = broadcast(tv1, {false, true}); tv1->split(1, tidx); auto tv3 = tv1->rFactor({-2}); TensorView* tv4 = 
makeDummyTensor(2); fusion.addInput(tv4); auto tv5 = add(tv2, tv4); fusion.addOutput(tv5); tv5->split(1, tidx); tv3->computeAt(tv5, 1); tv2->split(1, tidx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); tv5->axis(-1)->parallelize(ParallelType::TIDx); tv5->axis(0)->parallelize(ParallelType::BIDx); int x = 63, y = 200; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({x, y}, options); at::Tensor t4 = at::randn({x, y}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({t0, t4}); auto t3 = t0.sum({1}).unsqueeze(-1).expand({x, y}); auto t5 = t3.add(t4); // Error is larger than the default threshold TORCH_CHECK(t5.allclose(outputs[0], 1e-5, 1e-5)); } void testGPU_FusionReductionScheduler() { constexpr int bid_x = 80; constexpr int tid_x = 4096; constexpr int red_dim = 1; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0); fusion.addOutput(tv1); const auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({bid_x, tid_x}, options); // Apply reduction heuristic const at::ArrayRef inputs({input}); TORCH_CHECK( cuda::scheduleReduction(&fusion, inputs, tv1), "Reduction schedule was not generated!"); cuda::FusionExecutor fe; fe.compileFusion(&fusion); // no broadcasting needed, omitting the last optional argument; auto outputs = fe.runFusion({input}); auto aten_output = input.sum({red_dim}); TORCH_CHECK( aten_output.allclose(outputs[0]), "Error of: ", aten_output.sub(outputs[0]).abs().max()); } // Simple reduction parallelized on a symbolic size. 
void testGPU_FusionSymbolicReduction() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); // tv1[I0, R1] = tv0[I0, I1] TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0); fusion.addOutput(tv1); // Interface should just be a direct split with a Parallel type. We can // include the parallelize call if we do this. tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] TensorView* tv2 = tv1->rFactor({1}); // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1] // tv1[I0, R1oi{4}, R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] // Incrementally, can print in between for debugging tv0->computeAt(tv2, 1); tv2->computeAt(tv1, 1); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); int numel_x = 65000; int numel_y = 1025; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({numel_x, numel_y}, options); // How many threads to use for the block reduction int runtime_threadIdx_dim = 128; torch::jit::fuser::cuda::FusionExecutor executor; executor.compileFusion(&fusion); auto outputs = executor.runFusion( {input}, torch::jit::fuser::cuda::LaunchParams( -1, -1, -1, runtime_threadIdx_dim, -1, -1)); auto aten_output = input.sum({1}); TORCH_CHECK(aten_output.allclose(outputs[0])); } void testGPU_FusionReductionSchedulerMultiDimNonFastest() { const std::vector red_dims = {0, 2}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions const std::vector red_dims64 = {0, 2}; const std::vector tensor_dims_in = {5, 10, 15, 20}; const std::vector tensor_dims_out = {10, 20}; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(tensor_dims_in.size()); fusion.addInput(tv0); TensorView* tv1 = 
reductionOp(BinaryOpType::Add, red_dims, new Float(0), tv0); fusion.addOutput(tv1); const auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand(tensor_dims_in, options); at::Tensor cg_output = at::empty(tensor_dims_out, options); // Apply reduction heuristic const at::ArrayRef inputs({input}); TORCH_CHECK( cuda::scheduleReduction(&fusion, inputs, tv1), "Reduction schedule was not generated!"); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum(red_dims64); TORCH_CHECK( aten_output.allclose(outputs[0]), "Error of: ", aten_output.sub(outputs[0]).abs().max()); } void testGPU_FusionReductionSchedulerMultiDimFastest() { const std::vector red_dims = {1, 3}; // Copy is because CodeGen requires int and Pytorch requires int64_t // for a vector of reduction dimensions const std::vector red_dims64 = {1, 3}; const std::vector tensor_dims_in = {5, 10, 15, 20}; const std::vector tensor_dims_out = {5, 15}; Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(tensor_dims_in.size()); fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, red_dims, new Float(0), tv0); fusion.addOutput(tv1); const auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand(tensor_dims_in, options); TORCH_CHECK( cuda::scheduleReduction(&fusion, {input}, tv1), "Reduction schedule was not generated!"); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); auto aten_output = input.sum(red_dims64); TORCH_CHECK( aten_output.allclose(outputs[0]), "Error of: ", aten_output.sub(outputs[0]).abs().max()); } void testGPU_FusionReductionSchedulerDimShmoo() { std::vector fp16_usage = {false}; std::vector red_axis = {1, 0}; std::vector output_dims = {320, 640}; std::vector red_dims; // Tried to cut down the number 
iterations with just // doing every other power of 2. for (int i = 1; i <= 1024 * 1024; i <<= 2) { red_dims.push_back(i); } for (auto fp16 : fp16_usage) { for (auto& axis : red_axis) { for (auto& odim : output_dims) { for (auto& rdim : red_dims) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2, (fp16 ? DataType::Half : DataType::Float)); fusion.addInput(tv0); torch::jit::fuser::Val* tv0_cast = nullptr; if (fp16) { tv0_cast = castOp(DataType::Float, tv0); } TensorView* tv1 = reductionOp( BinaryOpType::Add, {axis}, new Float(0), (fp16 ? tv0_cast->as() : tv0)); TensorView* tv1_cast = nullptr; if (fp16) { tv1_cast = castOp(DataType::Half, tv1); } fusion.addOutput((fp16 ? tv1_cast : tv1)); auto options = at::TensorOptions() .dtype((fp16 ? at::kHalf : at::kFloat)) .device(at::kCUDA, 0); at::Tensor input = (axis ? at::rand({odim, rdim}, options) : at::rand({rdim, odim}, options)); const at::ArrayRef inputs({input}); c10::optional rparams = cuda::scheduleReduction(&fusion, inputs, tv1); TORCH_CHECK(rparams != c10::nullopt, "Reduction is not found!"); if (fp16) { if (axis == 0) { int tidx = rparams.value().lparams.bdimx(); tv1_cast->split(-1, tidx); tv1_cast->axis(-1)->parallelize(ParallelType::TIDx); tv1_cast->axis(-2)->parallelize(ParallelType::BIDx); } else { if (rparams.value().mul_reds_per_blk) { int tidy = rparams.value().lparams.bdimy(); tv1_cast->split(0, tidy); tv1_cast->axis(-1)->parallelize(ParallelType::TIDy); } tv1_cast->axis(0)->parallelize(ParallelType::BIDx); } } torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto cg_output = fe.runFusion({input}); auto aten_output = input.sum({axis}); TORCH_CHECK( aten_output.allclose(cg_output[0]), "Error of: ", aten_output.sub(cg_output[0]).abs().max()); } } } } } void testGPU_FusionCacheBefore() { // TVM Cache Write Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = add(tv0, new Float(1.0)); TensorView* tv2 = mul(tv1, new 
Float(3.0)); fusion.addInput(tv0); fusion.addOutput(tv2); // Before: TV2 = TV1 * 3 // After: TV3 = TV1 * 3; // TV2 = TV3; constexpr int BSX = 32; tv2->split(-1, BSX); tv0->computeAt(tv2, -1); // cache_before automatically applies ComputeAt to the cache TensorView tv2->cache_before(); // Thread and Block binding tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); constexpr int M = 32, N = 750; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({M, N}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); at::Tensor aten_output = (input + 1.0) * 3.0; TORCH_CHECK( aten_output.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", aten_output.sub(outputs[0]).abs().sum()); } void testGPU_FusionCacheAfter() { // TVM Cache Read Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = add(tv0, new Float(1.0)); TensorView* tv2 = mul(tv1, new Float(3.0)); fusion.addInput(tv0); fusion.addOutput(tv2); // Before: TV1 = TV0 + 1 // After: TV3 = TV0; // TV1 = TV3 + 1 constexpr int BSX = 32; tv2->split(-1, BSX); tv0->computeAt(tv2, -1); // cache_after automatically applies ComputeAt to the cache TensorView tv0->cache_after(); // Thread and Block binding tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); constexpr int M = 32, N = 457; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({M, N}, options); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto outputs = fe.runFusion({input}); at::Tensor aten_output = (input + 1.0) * 3.0; TORCH_CHECK( aten_output.allclose(outputs[0], 1e-5, 1e-5), "Error of: ", aten_output.sub(outputs[0]).abs().sum()); } void testGPU_FusionCacheIndirect() { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(2); TensorView* tv1 = 
makeDummyTensor(2);
  TensorView* tv2 = makeDummyTensor(2);
  TensorView* tv3 = makeDummyTensor(2);
  TensorView* tv4 = sub(tv2, tv3);
  TensorView* tv5 = add(tv1, tv4);
  TensorView* tv6 = sub(tv5, tv0);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);
  fusion.addInput(tv3);
  fusion.addOutput(tv6);
  // t6 = ((t1 + (t2 - t3)) - t0)

  // cache_after on inputs placed before schedule
  constexpr int BSX = 32;
  tv6->split(-1, BSX);
  tv2->computeAt(tv6, -1);

  // Cache the intermediate tv5 on both sides.
  tv5->cache_after();
  tv5->cache_before();

  // Thread and Block binding
  tv6->axis(0)->parallelize(ParallelType::BIDx);
  tv6->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 32, N = 810;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor in0 = at::rand({M, N}, options);
  at::Tensor in1 = at::rand({M, N}, options);
  at::Tensor in2 = at::rand({M, N}, options);
  at::Tensor in3 = at::rand({M, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({in0, in1, in2, in3});

  at::Tensor aten_output = (in1 + (in2 - in3)) - in0;
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().sum());
}

// Exercises four cache-placement cases (input cache_after, broadcast
// cache_before and cache_after, output cache_before) on a broadcast-multiply
// (outer-product) fusion.
void testGPU_FusionCacheBcast() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeDummyTensor(1); // (M, 1)
  TensorView* tv1 = broadcast(tv0, {false, true});
  TensorView* tv2 = makeDummyTensor(1); // (1, N)
  TensorView* tv3 = broadcast(tv2, {true, false});
  TensorView* tv4 = mul(tv1, tv3);
  fusion.addInput(tv0);
  fusion.addInput(tv2);
  fusion.addOutput(tv4);

  constexpr int BSX = 128;
  tv4->split(0, BSX);
  tv4->split(-1, BSX);
  tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); // M/BSX, N/BSY, BSX, BSY
  tv0->computeAt(tv4, 2);
  tv2->computeAt(tv4, 2);
  // 0, 1 | 2, 3, 4

  // Case 1
  tv0->cache_after();

  // Case 2
  tv1->cache_before();

  // Case 3
  tv1->cache_after();

  // Case 4
  TensorView* tv8 = tv4->cache_before();

  tv4->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(1)->parallelize(ParallelType::BIDy);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);

  // Manual Replay on TV3
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv8->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 92, N = 500;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M}, options);
  at::Tensor t1 = at::randn({N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1});

  // Reference: outer product of the two 1-D inputs.
  at::Tensor aten_output = t0.unsqueeze(1).matmul(t1.unsqueeze(0));
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}

// Combines a reduction, two broadcasts, cache_after on the reduction result
// and cache_before on the output within a single fusion.
void testGPU_FusionCacheComplex() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(2); // (N, N)
  TensorView* tv1 = makeDummyTensor(1); // (N)
  TensorView* tv2 = sum(tv0, {1}); // (N)
  TensorView* tv3 = broadcast(tv2, {false, true}); // (N, 1)
  TensorView* tv4 = broadcast(tv1, {true, false}); // (1, N)
  TensorView* tv5 = mul(tv3, tv4); // (N, N)
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv5);

  // Exception: Cache-Before on reduction Op
  // TensorView* tv9 = tv2->cache_before();

  constexpr int BSX = 128;
  tv5->split(0, BSX);
  tv5->split(-1, BSX);
  // M/BSX, BSX, N/BSX, BSX
  tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); // M/BSX, N/BSY, BSX, BSY
  tv0->computeAt(tv5, 2);
  tv1->computeAt(tv5, 2);
  // 0, 1 | 2, 3, 4

  tv2->cache_after();
  TensorView* tv7 = tv5->cache_before();

  tv5->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(1)->parallelize(ParallelType::BIDy);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);

  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv7->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int N = 800;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input1 = at::rand({N, N}, options);
  at::Tensor input2 = at::rand({N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input1, input2});

  // Reference: row-sum of input1 as a column times input2 as a row.
  at::Tensor aten_output =
      matmul(sum(input1, 1).unsqueeze(1), input2.unsqueeze(0));
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().sum());
}

// Two independent chains read the same input; each intermediate gets its own
// shared-memory cache via cache_before. Both outputs equal (input + 1) + 2.
void testGPU_FusionCacheMultiConsumer() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(1);
  TensorView* tv1 = add(tv0, new Float(1));
  TensorView* tv2 = add(tv1, new Float(2));
  TensorView* tv3 = add(tv0, new Float(1));
  TensorView* tv4 = add(tv3, new Float(2));

  fusion.addInput(tv0);
  fusion.addOutput(tv2);
  fusion.addOutput(tv4);

  tv1->computeAt(tv2, -1);
  tv3->computeAt(tv4, -1);

  auto tv5 = tv1->cache_before();
  auto tv6 = tv3->cache_before();
  tv5->setMemoryType(MemoryType::Shared);
  tv6->setMemoryType(MemoryType::Shared);

  // Fails because tensor must be recomputed twice
  // auto tv7 = tv0->cache_after();

  constexpr int N = 800;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input});

  auto aten_output = (input + 1) + 2;
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().sum());
  TORCH_CHECK(
      aten_output.allclose(outputs[1], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[1]).abs().sum());
}

// Stages both inputs of an elementwise multiply through shared memory using
// cache_after + MemoryType::Shared.
void testGPU_FusionSmem() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeDummyTensor(2); // (M, N)
  TensorView* tv1 = makeDummyTensor(2); // (M, N)
  TensorView* tv2 = mul(tv0, tv1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv2);

  // Schedule
  TensorView* tv3 = tv0->cache_after();
  TensorView* tv4 = tv1->cache_after();
  tv3->setMemoryType(MemoryType::Shared);
  tv4->setMemoryType(MemoryType::Shared);

  constexpr int BSY = 32;
  constexpr int BSX = 128;
  tv2->split(0, BSY);
  tv2->split(2, BSX);
  // M/BSX, BSX, N/BSX, BSX
tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); // M/BSX, N/BSX, BSX, BSX
  tv0->computeAt(tv2, 2);
  tv1->computeAt(tv2, 2);

  // Thread and Block binding
  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Binding
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 128, N = 10240;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, N}, options);
  at::Tensor t1 = at::randn({M, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1});

  at::Tensor aten_output = mul(t0, t1);
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}

// Sum-reduction over the middle axis of a 3-D tensor with the input staged
// in shared memory and the reduction split via rFactor.
void testGPU_FusionSmemReduce() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeDummyTensor(3); // M, K, N
  TensorView* tv1 = sum(tv0, {1}); // M, R, N
  fusion.addInput(tv0);
  fusion.addOutput(tv1);

  TensorView* tv2 = tv0->cache_after();
  tv2->setMemoryType(MemoryType::Shared);

  // Schedule
  constexpr int BSX = 32;
  tv1->split(2, BSX);
  tv1->split(1, 128);
  tv1->split(0, BSX);
  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
  tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}});
  TensorView* tv3 = tv1->rFactor({-2});

  tv0->computeAt(tv1, -2);
  tv0->computeAt(tv3, -2);

  // Thread and Block binding
  tv1->axis(0)->parallelize(ParallelType::BIDx);
  tv1->axis(1)->parallelize(ParallelType::BIDy);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Binding
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 154, K = 45, N = 1524;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, K, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs =
fe.runFusion({t0});

  at::Tensor aten_output = sum(t0, {1});
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}

// Block-tiled matmul: broadcast both operands, multiply, reduce over K, with
// all intermediates (and the rFactor tensor) placed in shared memory.
void testGPU_FusionSmemBlockGemm() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeDummyTensor(2); // (M, K)
  TensorView* tv1 = makeDummyTensor(2); // (K, N)
  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
  TensorView* tv4 = mul(tv2, tv3); // M, K, N
  TensorView* tv5 = sum(tv4, {1}); // M, R, N
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv5);

  // Schedule
  constexpr int BSX = 16;
  tv5->split(2, BSX);
  tv5->split(1, BSX);
  tv5->split(0, BSX);
  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
  tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}});
  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
  TensorView* tv6 = tv5->rFactor({-1});

  tv2->setMemoryType(MemoryType::Shared);
  tv3->setMemoryType(MemoryType::Shared);
  tv4->setMemoryType(MemoryType::Shared);
  tv6->setMemoryType(MemoryType::Shared);

  tv0->computeAt(tv5, 3);
  tv1->computeAt(tv5, 3);

  // Thread and Block binding
  tv5->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(1)->parallelize(ParallelType::BIDy);
  tv5->axis(-2)->parallelize(ParallelType::TIDy);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Binding
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv6->axis(-3)->parallelize(ParallelType::TIDy);
  tv6->axis(-2)->parallelize(ParallelType::TIDx);

  constexpr int M = 154, K = 45, N = 1524;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({K, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1});

  at::Tensor aten_output = matmul(t0, t1);
  TORCH_CHECK(
aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}

// Block gemm variant with cache_before on the output; the whole body is
// currently disabled via #if 0.
void testGPU_FusionSmemBlockGemmCache() {
#if 0
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Algorithm
  TensorView* tv0 = makeDummyTensor(2); // (M, K)
  TensorView* tv1 = makeDummyTensor(2); // (K, N)
  TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B)
  TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N)
  TensorView* tv4 = mul(tv2, tv3); // M, K, N
  TensorView* tv5 = sum(tv4, {1}); // M, R, N
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv5);

  // Schedule
  // Remove reduction axis from tv5
  // tv6 = (M, R, N)
  // tv5 = (M, N)
  TensorView* tv6 = tv5->cache_before();

  constexpr int BSX = 16;
  tv5->split(1, BSX);
  tv5->split(0, BSX);
  // M/BSX, BSX, N/BSX, BSX
  tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}});
  // tv5 = M/BSX, N/BSX, MSX, NSX

  tv6->computeAt(tv5, 2);
  // NOTE(review): this computeAt call is duplicated verbatim below — looks
  // redundant; confirm intent before removing (body is inside #if 0).
  tv6->computeAt(tv5, 2);

  tv6->split(-1, BSX);
  // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX
  tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}});
  // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX
  TensorView* tv7 = tv6->rFactor({-1});
  // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr
  // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX

  tv0->computeAt(tv6, 3);
  tv1->computeAt(tv6, 3);
  tv0->computeAt(tv7, 3);
  tv1->computeAt(tv7, 3);

  tv2->setMemoryType(MemoryType::Shared);
  tv3->setMemoryType(MemoryType::Shared);
  tv4->setMemoryType(MemoryType::Shared);
  tv6->setMemoryType(MemoryType::Shared);
  tv7->setMemoryType(MemoryType::Shared);
  // Memory Type

  // Thread and Block binding
  tv5->axis(0)->parallelize(ParallelType::BIDx);
  tv5->axis(1)->parallelize(ParallelType::BIDy);
  tv5->axis(-2)->parallelize(ParallelType::TIDy);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);
  // Manual Binding
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv7->axis(-3)->parallelize(ParallelType::TIDy);
tv7->axis(-2)->parallelize(ParallelType::TIDx);
  tv6->axis(-2)->parallelize(ParallelType::TIDy);
  tv6->axis(-1)->parallelize(ParallelType::TIDx);

  constexpr int M = 154, K = 45, N = 1524;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({K, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1});

  at::Tensor aten_output = matmul(t0, t1);
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
#endif
}

// Const-scalar propagation: chained multiplies of constant Ints must keep
// reporting isConstScalar().
void testGPU_FusionConstCheck() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto one = new Int(1);
  TORCH_CHECK(one->isConstScalar());

  auto one_x2 = mul(one, one);
  TORCH_CHECK(one_x2->isConstScalar());

  auto one_x3 = mul(one_x2, one);
  TORCH_CHECK(one_x3->isConstScalar());

  auto one_x4 = mul(one_x3, one);
  TORCH_CHECK(one_x4->isConstScalar());
}

// Reduction whose rFactor axis is parallelized with ParallelType::Unroll;
// verifies the unrolled kernel still produces the correct row sums.
void testGPU_FusionUnrollWithAlloc() {
  const std::vector tensor_dims_in = {128, 128};
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(tensor_dims_in.size());
  fusion.addInput(tv0);

  TensorView* tv1 = add(tv0, new Float(0));
  TensorView* tv2 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv1);
  fusion.addOutput(tv2);

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand(tensor_dims_in, options);
  at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options);
  // const at::ArrayRef inputs({input});

  // Schedule
  tv2->split(1, 32);
  tv2->split(1, 4); // unroll

  auto tv2_rf = tv2->rFactor({-3, -2});

  tv2->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);

  tv2_rf->axis(0)->parallelize(ParallelType::BIDx);
  tv2_rf->axis(-1)->parallelize(ParallelType::TIDx);
  tv2_rf->axis(-2)->parallelize(ParallelType::Unroll);

  tv1->computeAt(tv2_rf, -1);

  torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input});

  auto aten_output = (input + 0).sum(1);
  TORCH_CHECK(
      aten_output.allclose(outputs[0]),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());
}

// Test isZeroInt
void testGPU_FusionIsZeroInt() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  Int* x = new Int(0);
  Int* y = new Int(1);
  Val* z = mul(x, y);
  TORCH_CHECK(x->isZeroInt());
  TORCH_CHECK(!y->isZeroInt());
  TORCH_CHECK(!z->isZeroInt());
}

// Test isOneInt
void testGPU_FusionIsOneInt() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  Int* x = new Int(1);
  Int* y = new Int(1);
  Val* z = mul(x, y);
  TORCH_CHECK(x->isOneInt());
  TORCH_CHECK(y->isOneInt());
  TORCH_CHECK(!z->isOneInt());
}

// This is to verify no cycle of computeAt is created. A more complex
// variation of this pattern appears in one of the Python tests
// (test_random_topo).
void testGPU_FusionComputeAtNonterminatingOutput() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);

  // Common intermediate tensor
  auto tv1 = add(tv0, new Float(1));
  // tv1 -> tv2
  auto tv2 = add(tv1, new Float(2));
  // tv1 -> tv3 -> tv4
  auto tv3 = add(tv1, new Float(3));
  auto tv4 = add(tv3, new Float(4));

  // NOTE: This should no longer occur as of PR #201.
  // The order of adding outputs matters. If tv3 is added before tv4,
  // it should be fine. However, if tv4 is added before tv3, there
  // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
  // first, and then tv4->tv3 is created at the final phase of
  // computeAt (ComputeAt::setupOutputs).
fusion.addOutput(tv2);
  fusion.addOutput(tv4);
  fusion.addOutput(tv3);

  tv0->computeAt(tv2, -1);

  TORCH_CHECK(
      !(tv3->getComputeAtView() == tv4 && tv4->getComputeAtView() == tv3),
      "ComputeAt cycle detected between tv3 and tv4");

  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand(100, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({input});

  // Outputs arrive in the order they were added above: tv2, tv4, tv3.
  auto& output_tv2 = outputs[0];
  auto& output_tv4 = outputs[1];
  auto& output_tv3 = outputs[2];

  auto aten_t1 = input + 1;
  auto aten_t2 = aten_t1 + 2;
  auto aten_t3 = aten_t1 + 3;
  auto aten_t4 = aten_t3 + 4;

  TORCH_CHECK(
      aten_t2.allclose(output_tv2),
      "Error of: ",
      aten_t2.sub(output_tv2).abs().max());
  TORCH_CHECK(
      aten_t3.allclose(output_tv3),
      "Error of: ",
      aten_t3.sub(output_tv3).abs().max());
  TORCH_CHECK(
      aten_t4.allclose(output_tv4),
      "Error of: ",
      aten_t4.sub(output_tv4).abs().max());

  return;
}

// Traversal-order test: a diamond of adds with computeAt on one branch only;
// all three outputs must still be produced correctly.
void testGPU_FusionTraversalOrder1() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = add(tv0, new Float(1));
  TensorView* tv2 = add(tv0, new Float(2));
  TensorView* tv3 = add(tv1, new Float(3));
  TensorView* tv4 = add(tv1, new Float(4));

  fusion.addOutput(tv2);
  fusion.addOutput(tv3);
  fusion.addOutput(tv4);

  tv1->computeAt(tv3, -1);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({10, 10}, options);

  at::Tensor cg_output_tv2 = at::empty_like(input, options);
  at::Tensor cg_output_tv3 = at::empty_like(input, options);
  at::Tensor cg_output_tv4 = at::empty_like(input, options);
  fe.runFusion({input}, {cg_output_tv2, cg_output_tv3, cg_output_tv4});

  auto t1 = input + 1;
  auto t2 = input + 2;
  auto t3 = t1 + 3;
  auto t4 = t1 + 4;

  TORCH_CHECK(
      t2.allclose(cg_output_tv2),
      "tv2 error of: ",
t2.sub(cg_output_tv2).abs().max()); TORCH_CHECK( t3.allclose(cg_output_tv3), "tv5 error of: ", t3.sub(cg_output_tv3).abs().max()); TORCH_CHECK( t4.allclose(cg_output_tv4), "tv4 error of: ", t4.sub(cg_output_tv4).abs().max()); } void testGPU_FusionTraversalOrder2() { Fusion fusion; FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); fusion.addInput(tv0); TensorView* tv1 = add(tv0, new Float(1)); TensorView* tv2 = add(tv1, new Float(2)); TensorView* tv3 = add(tv0, new Float(3)); TensorView* tv4 = add(tv3, new Float(4)); TensorView* tv5 = add(tv1, tv3); fusion.addOutput(tv2); fusion.addOutput(tv4); fusion.addOutput(tv5); tv1->computeAt(tv5, -1); tv3->computeAt(tv5, -1); torch::jit::fuser::cuda::FusionExecutor fe; fe.compileFusion(&fusion); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({10, 10}, options); at::Tensor cg_output_tv2 = at::empty_like(input, options); at::Tensor cg_output_tv4 = at::empty_like(input, options); at::Tensor cg_output_tv5 = at::empty_like(input, options); fe.runFusion({input}, {cg_output_tv2, cg_output_tv4, cg_output_tv5}); auto t1 = input + 1; auto t2 = t1 + 2; auto t3 = input + 3; auto t4 = t3 + 4; auto t5 = t1 + t3; TORCH_CHECK( t2.allclose(cg_output_tv2), "tv2 error of: ", t2.sub(cg_output_tv2).abs().max()); TORCH_CHECK( t4.allclose(cg_output_tv4), "tv4 error of: ", t4.sub(cg_output_tv4).abs().max()); TORCH_CHECK( t5.allclose(cg_output_tv5), "tv5 error of: ", t5.sub(cg_output_tv5).abs().max()); } void testGPU_FusionTraversalOrder3() { for (int i = 0; i < 2; ++i) { Fusion fusion; FusionGuard fg(&fusion); TensorView* tv0 = makeDummyTensor(1); fusion.addInput(tv0); TensorView* tv1 = add(tv0, new Float(1)); TensorView* tv2 = add(tv1, new Float(2)); TensorView* tv3 = add(tv0, new Float(3)); TensorView* tv4 = add(tv3, new Float(4)); TensorView* tv5 = add(tv1, tv3); fusion.addOutput(tv2); fusion.addOutput(tv4); fusion.addOutput(tv5); const int tile 
= 32;
    tv1->split(-1, tile);
    tv2->split(-1, tile);
    tv3->split(-1, tile);
    tv4->split(-1, tile);
    tv5->split(-1, tile);

    // Second iteration swaps which branch binds to the outer position.
    auto compute_at_outer = tv1;
    auto compute_at_inner = tv3;
    if (i == 1) {
      std::swap(compute_at_inner, compute_at_outer);
    }

    compute_at_outer->computeAt(tv5, -2);
    compute_at_inner->computeAt(tv5, -1);

    torch::jit::fuser::cuda::FusionExecutor fe;
    fe.compileFusion(&fusion);

    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
    at::Tensor input = at::rand({100}, options);

    at::Tensor cg_output_tv2 = at::empty_like(input, options);
    at::Tensor cg_output_tv4 = at::empty_like(input, options);
    at::Tensor cg_output_tv5 = at::empty_like(input, options);
    fe.runFusion({input}, {cg_output_tv2, cg_output_tv4, cg_output_tv5});

    auto t1 = input + 1;
    auto t2 = t1 + 2;
    auto t3 = input + 3;
    auto t4 = t3 + 4;
    auto t5 = t1 + t3;

    TORCH_CHECK(
        t2.allclose(cg_output_tv2),
        "tv2 error of: ",
        t2.sub(cg_output_tv2).abs().max());
    TORCH_CHECK(
        t4.allclose(cg_output_tv4),
        "tv4 error of: ",
        t4.sub(cg_output_tv4).abs().max());
    TORCH_CHECK(
        t5.allclose(cg_output_tv5),
        "tv5 error of: ",
        t5.sub(cg_output_tv5).abs().max());
  }
}

// Traversal-order test: two completely independent expression trees in one
// fusion, each with its own computeAt.
void testGPU_FusionTraversalOrder4() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // First tree
  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);
  TensorView* tv1 = add(tv0, new Float(1));
  TensorView* tv2 = add(tv1, new Float(2));
  TensorView* tv3 = add(tv1, new Float(3));
  fusion.addOutput(tv2);
  fusion.addOutput(tv3);

  // Second tree
  TensorView* tv4 = makeDummyTensor(1);
  fusion.addInput(tv4);
  TensorView* tv5 = add(tv4, new Float(5));
  TensorView* tv6 = add(tv5, new Float(6));
  TensorView* tv7 = add(tv5, new Float(7));
  fusion.addOutput(tv6);
  fusion.addOutput(tv7);

  tv1->computeAt(tv2, -1);
  tv5->computeAt(tv6, -1);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::rand({100}, options);
  at::Tensor t4 = at::rand_like(t0, options);
at::Tensor cg_output_tv2 = at::empty_like(t0, options);
  at::Tensor cg_output_tv3 = at::empty_like(t0, options);
  at::Tensor cg_output_tv6 = at::empty_like(t0, options);
  at::Tensor cg_output_tv7 = at::empty_like(t0, options);
  fe.runFusion(
      {t0, t4}, {cg_output_tv2, cg_output_tv3, cg_output_tv6, cg_output_tv7});

  auto t1 = t0 + 1;
  auto t2 = t1 + 2;
  auto t3 = t1 + 3;
  auto t5 = t4 + 5;
  auto t6 = t5 + 6;
  auto t7 = t5 + 7;

  TORCH_CHECK(
      t2.allclose(cg_output_tv2),
      "tv2 error of: ",
      t2.sub(cg_output_tv2).abs().max());
  TORCH_CHECK(
      t3.allclose(cg_output_tv3),
      "tv3 error of: ",
      t3.sub(cg_output_tv3).abs().max());
  TORCH_CHECK(
      t6.allclose(cg_output_tv6),
      "tv6 error of: ",
      t6.sub(cg_output_tv6).abs().max());
  TORCH_CHECK(
      t7.allclose(cg_output_tv7),
      "tv7 error of: ",
      t7.sub(cg_output_tv7).abs().max());
}

// Traversal-order test: intermediates tv1/tv3 are themselves fusion outputs
// while tv2/tv4 are computeAt'ed into the joint consumer tv5.
void testGPU_FusionTraversalOrder5() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);
  TensorView* tv1 = add(tv0, new Float(1));
  TensorView* tv2 = add(tv1, new Float(2));
  TensorView* tv3 = add(tv0, new Float(3));
  TensorView* tv4 = add(tv3, new Float(4));
  TensorView* tv5 = add(tv2, tv4);

  fusion.addOutput(tv1);
  fusion.addOutput(tv3);
  fusion.addOutput(tv5);

  tv2->computeAt(tv5, -1);
  tv4->computeAt(tv5, -1);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::rand({100}, options);

  at::Tensor cg_output_tv1 = at::empty_like(t0, options);
  at::Tensor cg_output_tv3 = at::empty_like(t0, options);
  at::Tensor cg_output_tv5 = at::empty_like(t0, options);

  fe.runFusion({t0}, {cg_output_tv1, cg_output_tv3, cg_output_tv5});

  auto t1 = t0 + 1;
  auto t2 = t1 + 2;
  auto t3 = t0 + 3;
  auto t4 = t3 + 4;
  auto t5 = t2 + t4;

  TORCH_CHECK(
      t1.allclose(cg_output_tv1),
      "tv1 error of: ",
      t1.sub(cg_output_tv1).abs().max());
  TORCH_CHECK(
      t3.allclose(cg_output_tv3),
      "tv3 error of: ",
      t3.sub(cg_output_tv3).abs().max());
  TORCH_CHECK(
t5.allclose(cg_output_tv5),
      "tv5 error of: ",
      t5.sub(cg_output_tv5).abs().max());
}

// Traversal-order test: a diamond that re-joins (tv3 = tv1 + tv2) with
// computeAt applied at different relative positions.
void testGPU_FusionTraversalOrder6() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);
  TensorView* tv1 = add(tv0, new Float(1));
  TensorView* tv2 = add(tv0, new Float(2));
  TensorView* tv3 = add(tv1, tv2);
  TensorView* tv4 = add(tv3, new Float(4));

  fusion.addOutput(tv4);

  tv1->split(0, 32);
  tv2->split(0, 32);
  tv3->split(0, 32);
  tv4->split(0, 32);

  tv3->computeAt(tv4, -2);
  tv1->computeAt(tv3, -1);
  tv2->computeAt(tv3, -2);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::rand({100}, options);

  at::Tensor cg_output_tv4 = at::empty_like(t0, options);

  fe.runFusion({t0}, {cg_output_tv4});

  auto t1 = t0 + 1;
  auto t2 = t0 + 2;
  auto t3 = t1 + t2;
  auto t4 = t3 + 4;

  TORCH_CHECK(
      t4.allclose(cg_output_tv4),
      "tv4 error of: ",
      t4.sub(cg_output_tv4).abs().max());
}

// Traversal-order test: every tensor is split three times and computeAt
// targets a mix of inner and outer loop-nest positions.
void testGPU_FusionTraversalOrder7() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);
  TensorView* tv1 = add(tv0, new Float(1));
  TensorView* tv2 = add(tv1, new Float(2));
  TensorView* tv3 = add(tv0, new Float(3));
  TensorView* tv4 = add(tv3, new Float(4));
  TensorView* tv5 = add(tv2, tv4);

  fusion.addOutput(tv5);

  TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5};
  for (auto tv : tvs) {
    tv->split(0, 2);
    tv->split(0, 4);
    tv->split(0, 8);
  }

  // computeAt into inner loop nests
  tv1->computeAt(tv2, -1);
  tv3->computeAt(tv4, -2);

  tv2->computeAt(tv5, -4);
  tv4->computeAt(tv5, -3);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::rand({100}, options);

  at::Tensor cg_output_tv5 = at::empty_like(t0, options);

  fe.runFusion({t0}, {cg_output_tv5});

  auto t1 = t0 + 1;
  auto t2 = t1 + 2;
  auto t3 = t0 + 3;
  auto t4 = t3 + 4;
  auto t5 = t2 + t4;
TORCH_CHECK(
      t5.allclose(cg_output_tv5),
      "tv5 error of: ",
      t5.sub(cg_output_tv5).abs().max());
}

// Test predication of grid reduction
void testGPU_FusionThreadPredicate() {
  const int gdimx = 4;
  const int bdimx = 128;

  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);

  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
  TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1);
  TensorView* tv3 = add(tv0, new Float(2));

  fusion.addOutput(tv3);
  fusion.addOutput(tv2);

  tv1->split(1, bdimx);
  tv1->split(1, gdimx);
  tv3->split(1, bdimx);
  tv3->split(1, gdimx);

  TensorView* tv1_rf = tv1->rFactor({1});

  tv1->computeAt(tv2, -1);

  tv1->axis(0)->parallelize(ParallelType::BIDy);
  tv1_rf->axis(0)->parallelize(ParallelType::BIDy);
  tv2->axis(0)->parallelize(ParallelType::BIDy);
  tv1->axis(-2)->parallelize(ParallelType::BIDx);
  tv1_rf->axis(-2)->parallelize(ParallelType::BIDx);
  tv1->axis(-1)->parallelize(ParallelType::TIDx);
  tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);

  tv3->axis(3)->parallelize(ParallelType::TIDx);
  tv3->axis(2)->parallelize(ParallelType::BIDx);
  tv3->axis(0)->parallelize(ParallelType::BIDy);

  int numel_x = 100;
  int numel_y = 1000;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({numel_x, numel_y}, options);
  at::Tensor cg_output_tv2 = at::empty({numel_x}, options);
  at::Tensor cg_output_tv3 = at::empty_like(input, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  // Output buffer order matches fusion.addOutput order: tv3, then tv2.
  fe.runFusion({input}, {cg_output_tv3, cg_output_tv2});

  auto aten_output_tv2 = -input.sum({1});
  TORCH_CHECK(aten_output_tv2.allclose(cg_output_tv2));

  auto aten_output_tv3 = input + 2.0;
  TORCH_CHECK(aten_output_tv3.allclose(cg_output_tv3));
}

} // namespace jit
} // namespace torch

#endif // #if defined(USE_CUDA)