TypeError: dot() got an unexpected keyword argument 'trans_b' #1098

Closed
conceptofmind opened this issue Jan 25, 2023 · 19 comments

@conceptofmind

Hi all,

I am working on integrating the Triton version of Flash Attention in a GPT-like model.

For some reason I am receiving this error: TypeError: dot() got an unexpected keyword argument 'trans_b'

Here is a snippet of the code where the error is occurring:

import math
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
import flash_attn.flash_attn_triton as flash_attn_triton

from einops import rearrange, repeat

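@dataclass
class GPT2Config:
    # Type annotations are required for @dataclass to treat these as fields.
    num_heads: int = 8
    head_dim: int = 64
    hidden_dim: int = 512
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1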

class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    # noinspection PyMethodMayBeStatic
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (
            0.5 * x * (1.0 + torch.tanh(
                math.sqrt(2.0 / math.pi)
                * (x + 0.044715 * torch.pow(x, 3.0))
            ))
        )

class Conv1D(nn.Module):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.
    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):

        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)

        return x

class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        inner_dim = config.num_heads * config.head_dim
        self.c_attn = Conv1D(3 * inner_dim, config.hidden_dim)
        self.c_proj = Conv1D(config.hidden_dim, inner_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        num_heads, head_dim = self.config.num_heads, self.config.head_dim

        # x.shape -> torch.Size([1, 512, 512])

        qkv = self.c_attn(x).chunk(3, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h = num_heads), qkv)

        # batch_size, seq_len, num_heads, head_dim
        # q.shape, k.shape, v.shape -> torch.Size([1, 512, 8, 64])

        flash_attn_out = flash_attn_triton.flash_attn_func(
            q, 
            k, 
            v,
            None,
            True,
            1.0
        )

        out = flash_attn_out.contiguous().view(
            batch_size, seq_len, hidden_dim
            )

        attn_out = self.c_proj(out)

        attn_out = self.resid_dropout(attn_out)

        return attn_out

# Test GPT2Attention
config = GPT2Config()

attention = GPT2Attention(config).to(torch.float16).cuda()

print(attention(torch.randn(1, 512, 512).to(torch.float16).cuda()))

Any help would be appreciated.

Thank you,

Enrico

@ptillet
Collaborator

ptillet commented Jan 25, 2023

On the latest master branch, you can use tl.trans(b) explicitly. See https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#L165 for an example.
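
For kernels written against the older API, the change amounts to replacing the `trans_b=True` keyword with an explicit `tl.trans` call on the second operand. As a rough sketch (not part of Triton or flash-attn; the helper name `rewrite_trans_b` is made up for illustration, and the regex only handles simple operands without commas or parentheses), such a rewrite could be automated like this:

```python
import re

# Matches tl.dot(a, b, trans_b=True) where a and b are simple expressions
# without commas or parentheses. Anything more complex is left untouched.
_TRANS_B = re.compile(
    r"tl\.dot\(\s*([^,()]+?)\s*,\s*([^,()]+?)\s*,\s*trans_b\s*=\s*True\s*\)"
)


def rewrite_trans_b(source: str) -> str:
    """Rewrite tl.dot(a, b, trans_b=True) as tl.dot(a, tl.trans(b))."""
    return _TRANS_B.sub(r"tl.dot(\1, tl.trans(\2))", source)


old_line = "qk += tl.dot(q, k, trans_b=True)"
print(rewrite_trans_b(old_line))  # prints: qk += tl.dot(q, tl.trans(k))
```

Calls with nested expressions as operands are deliberately skipped by this pattern and would need to be edited by hand.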

@conceptofmind
Author

I will update the source code for the Triton Flash Attention implementation with your recommendation and report back.

Thank you for the advice.

Enrico

@conceptofmind
Author

@ptillet I tried both the latest stable release of Triton and the latest available nightly build. I am using Python 3.9 and CUDA 11.3. I am now receiving the error AttributeError: module 'triton.language' has no attribute 'trans' at:

        qk += tl.dot(q, tl.trans(k))

Thank you,

Enrico

@ptillet
Collaborator

ptillet commented Jan 25, 2023

You should install from source. The nightly build is about a month and a half old, as new nightly builds are way less stable than the old one. We plan a Triton alpha release in ~2 weeks.
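
Whether an installed build is new enough can be checked before patching any kernels. A minimal sketch, assuming only that newer builds expose `triton.language.trans` as described above; the helper name `has_tl_trans` is hypothetical:

```python
import importlib


def has_tl_trans(module_name: str = "triton.language") -> bool:
    """Return True if the named module can be imported and exposes `trans`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Triton is not installed in this environment at all.
        return False
    return hasattr(module, "trans")


if has_tl_trans():
    print("tl.trans is available: use tl.dot(q, tl.trans(k))")
else:
    print("tl.trans is missing: this build predates the new API, or Triton is not installed")
```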

@conceptofmind
Author

Ok. I will try installing from source now.

@vedantroy
Contributor

@conceptofmind, let me know if you get building from source to work. So far I haven't managed to get it to work, which could 100% be an issue on my end.

@conceptofmind
Author

conceptofmind commented Jan 25, 2023

@vedantroy I have not had any success building from source either. It throws an error stating that scikit-build is missing. I ran pip3 install scikit-build, but the error still occurs.

@ptillet
Collaborator

ptillet commented Jan 25, 2023

What error are you guys having? The build should be quite smooth.

@conceptofmind
Author

conceptofmind commented Jan 25, 2023

git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .

Here is where it breaks:

/home/henry/triton/lib/Target/LLVMIR/LLVMIRTranslation.cpp:21:10: fatal error: filesystem: No such file or directory
     #include <filesystem>
              ^~~~~~~~~~~~
    compilation terminated.
    lib/Target/LLVMIR/CMakeFiles/obj.TritonLLVMIR.dir/build.make:78: recipe for target 'lib/Target/LLVMIR/CMakeFiles/obj.TritonLLVMIR.dir/LLVMIRTranslation.cpp.o' failed
    make[2]: *** [lib/Target/LLVMIR/CMakeFiles/obj.TritonLLVMIR.dir/LLVMIRTranslation.cpp.o] Error 1
    make[2]: Leaving directory '/home/henry/triton/python/build/temp.linux-x86_64-3.9'
    CMakeFiles/Makefile2:2755: recipe for target 'lib/Target/LLVMIR/CMakeFiles/obj.TritonLLVMIR.dir/all' failed
    make[1]: *** [lib/Target/LLVMIR/CMakeFiles/obj.TritonLLVMIR.dir/all] Error 2
    make[1]: Leaving directory '/home/henry/triton/python/build/temp.linux-x86_64-3.9'
    Makefile:148: recipe for target 'all' failed
    make: *** [all] Error 2
    Traceback (most recent call last):
      File "<string>", line 2, in <module>
      File "<pip-setuptools-caller>", line 34, in <module>
      File "/home/henry/triton/python/setup.py", line 171, in <module>
        setup(
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/site-packages/setuptools/__init__.py", line 153, in setup
        return distutils.core.setup(**attrs)
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/distutils/core.py", line 148, in setup
        dist.run_commands()
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/distutils/dist.py", line 966, in run_commands
        self.run_command(cmd)
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/distutils/dist.py", line 985, in run_command
        cmd_obj.run()
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/site-packages/setuptools/command/develop.py", line 34, in run
        self.install_for_development()
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/site-packages/setuptools/command/develop.py", line 114, in install_for_development
        self.run_command('build_ext')
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/distutils/cmd.py", line 313, in run_command
        self.distribution.run_command(command)
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/distutils/dist.py", line 985, in run_command
        cmd_obj.run()
      File "/home/henry/triton/python/setup.py", line 126, in run
        self.build_extension(ext)
      File "/home/henry/triton/python/setup.py", line 168, in build_extension
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
      File "/home/henry/anaconda3/envs/palm/lib/python3.9/subprocess.py", line 373, in check_call
        raise CalledProcessError(retcode, cmd)
    subprocess.CalledProcessError: Command '['cmake', '--build', '.', '--config', 'TritonRelBuildWithAsserts', '--', '-j24']' returned non-zero exit status 2.
    [end of output]

@conceptofmind
Author

Obtaining file:///home/henry/triton/python
Collecting cmake (from triton==2.0.0)
  Downloading https://files.pythonhosted.org/packages/e9/67/3d545c3a706bc427b1bd2a9108e3986dfc8c1450be0da8149fe591902875/cmake-3.25.0.tar.gz
    Complete output from command python setup.py egg_info:
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "/tmp/pip-build-7nv5_q__/cmake/setup.py", line 8, in <module>
        from skbuild import setup
    ModuleNotFoundError: No module named 'skbuild'
    
    ----------------------------------------

@vedantroy
Contributor

@ptillet
Here's my build error: #1096
If you need more reproducibility, I can try to throw together a docker container that fails to build triton.

@ptillet
Collaborator

ptillet commented Jan 26, 2023

@conceptofmind what compiler version are you using? It seems like your compiler doesn't support <filesystem>.

@vedantroy it looks like the LLVM download was interrupted somehow during your install. The Triton installer currently doesn't check the MD5 checksum or otherwise verify the integrity of the downloaded archive. Can you delete ~/.triton/cache and retry?
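
Until the installer verifies its downloads, a truncated archive can be detected by hand before retrying. The sketch below is not what Triton's installer does; it simply tries to read every member of a tar archive with the Python standard library, and the function name `archive_is_readable` is made up for illustration:

```python
import lzma
import tarfile
import zlib


def archive_is_readable(path: str) -> bool:
    """Return True if every member of the tar archive can be read to the end."""
    try:
        # "r:*" lets tarfile detect the compression (e.g. .tar.gz or .tar.xz).
        with tarfile.open(path, mode="r:*") as archive:
            for member in archive:
                if member.isfile():
                    stream = archive.extractfile(member)
                    # Read in chunks so large files do not fill memory.
                    while stream.read(1 << 20):
                        pass
    except (tarfile.TarError, lzma.LZMAError, zlib.error, EOFError, OSError):
        # A truncated or corrupted download typically fails in one of these ways.
        return False
    return True
```

If it returns False for a downloaded archive, deleting the download and retrying, as suggested above, is the simpler fix.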

@conceptofmind
Author

conceptofmind commented Jan 26, 2023

@ptillet I am pretty sure it is GCC on Ubuntu 18.04; I do not think it is MinGW. I will have to check the version when I get back. I was also having a problem with the installation constantly stopping: I had to keep rerunning it to make progress, and it stopped working at around 91%.

@vedantroy
Contributor

vedantroy commented Jan 26, 2023

@conceptofmind @ptillet, I can reproduce conceptofmind's skbuild error with the following Dockerfile:

FROM nvidia/cuda:11.6.0-devel-ubuntu18.04 as dep_triton

ARG DEBIAN_FRONTEND=noninteractive
RUN apt update && apt install -y git software-properties-common
RUN apt install -y python3-pip
RUN git clone https://github.com/openai/triton.git \
    && cd triton/python \
    && git checkout d3e753b5c00bbae855b283adf3d3a5d6d1547830 \
    && python3 -m pip install cmake scikit-build \
    && python3 -m pip wheel --wheel-dir /tmp/dist . --verbose

I can avoid the above error by using this line:

FROM nvidia/cuda:11.6.1-devel-ubuntu20.04 as dep_triton

but then I run into this error:

#26 66.54   running bdist_wheel
#26 66.54   running build
#26 66.54   running build_py
#26 66.59   creating build
#26 66.59   creating build/lib.linux-x86_64-3.8
#26 66.59   creating build/lib.linux-x86_64-3.8/triton
#26 66.59   copying triton/__init__.py -> build/lib.linux-x86_64-3.8/triton
#26 66.59   copying triton/testing.py -> build/lib.linux-x86_64-3.8/triton
#26 66.59   copying triton/compiler.py -> build/lib.linux-x86_64-3.8/triton
#26 66.59   copying triton/utils.py -> build/lib.linux-x86_64-3.8/triton
#26 66.59   package init file 'triton/_C/__init__.py' not found (or not a regular file)
#26 66.59   creating build/lib.linux-x86_64-3.8/triton/language
#26 66.59   copying triton/language/core.py -> build/lib.linux-x86_64-3.8/triton/language
#26 66.59   copying triton/language/extern.py -> build/lib.linux-x86_64-3.8/triton/language
#26 66.59   copying triton/language/__init__.py -> build/lib.linux-x86_64-3.8/triton/language
#26 66.59   copying triton/language/random.py -> build/lib.linux-x86_64-3.8/triton/language
#26 66.59   copying triton/language/libdevice.py -> build/lib.linux-x86_64-3.8/triton/language
#26 66.59   copying triton/language/semantic.py -> build/lib.linux-x86_64-3.8/triton/language
#26 66.59   creating build/lib.linux-x86_64-3.8/triton/tools
#26 66.59   copying triton/tools/__init__.py -> build/lib.linux-x86_64-3.8/triton/tools
#26 66.59   copying triton/tools/aot.py -> build/lib.linux-x86_64-3.8/triton/tools
#26 66.59   copying triton/tools/disasm.py -> build/lib.linux-x86_64-3.8/triton/tools
#26 66.59   copying triton/tools/build_extern.py -> build/lib.linux-x86_64-3.8/triton/tools
#26 66.59   creating build/lib.linux-x86_64-3.8/triton/impl
#26 66.59   copying triton/impl/base.py -> build/lib.linux-x86_64-3.8/triton/impl
#26 66.59   copying triton/impl/__init__.py -> build/lib.linux-x86_64-3.8/triton/impl
#26 66.59   creating build/lib.linux-x86_64-3.8/triton/ops
#26 66.59   copying triton/ops/__init__.py -> build/lib.linux-x86_64-3.8/triton/ops
#26 66.60   copying triton/ops/cross_entropy.py -> build/lib.linux-x86_64-3.8/triton/ops
#26 66.60   copying triton/ops/matmul_perf_model.py -> build/lib.linux-x86_64-3.8/triton/ops
#26 66.60   copying triton/ops/matmul.py -> build/lib.linux-x86_64-3.8/triton/ops
#26 66.60   creating build/lib.linux-x86_64-3.8/triton/runtime
#26 66.60   copying triton/runtime/jit.py -> build/lib.linux-x86_64-3.8/triton/runtime
#26 66.60   copying triton/runtime/__init__.py -> build/lib.linux-x86_64-3.8/triton/runtime
#26 66.60   copying triton/runtime/autotuner.py -> build/lib.linux-x86_64-3.8/triton/runtime
#26 66.60   creating build/lib.linux-x86_64-3.8/triton/ops/blocksparse
#26 66.60   copying triton/ops/blocksparse/__init__.py -> build/lib.linux-x86_64-3.8/triton/ops/blocksparse
#26 66.60   copying triton/ops/blocksparse/softmax.py -> build/lib.linux-x86_64-3.8/triton/ops/blocksparse
#26 66.60   copying triton/ops/blocksparse/matmul.py -> build/lib.linux-x86_64-3.8/triton/ops/blocksparse
#26 66.60   running egg_info
#26 66.60   creating triton.egg-info
#26 66.60   writing triton.egg-info/PKG-INFO
#26 66.60   writing dependency_links to triton.egg-info/dependency_links.txt
#26 66.60   writing requirements to triton.egg-info/requires.txt
#26 66.60   writing top-level names to triton.egg-info/top_level.txt
#26 66.60   writing manifest file 'triton.egg-info/SOURCES.txt'
#26 66.60   reading manifest file 'triton.egg-info/SOURCES.txt'
#26 66.60   reading manifest template 'MANIFEST.in'
#26 66.60   writing manifest file 'triton.egg-info/SOURCES.txt'
#26 66.61   copying triton/language/libdevice.10.bc -> build/lib.linux-x86_64-3.8/triton/language
#26 66.61   running build_ext
#26 66.65   downloading and extracting https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz ...
#26 67.26   downloading and extracting https://github.com/llvm/llvm-project/releases/download/llvmorg-14.0.0/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz ...
#26 136.9   CMake Warning:
#26 136.9     Ignoring extra path from command line:
#26 136.9 
#26 136.9      "/tmp"
#26 136.9 
#26 136.9 
#26 136.9   CMake Error: The source directory "/tmp" does not appear to contain CMakeLists.txt.
#26 136.9   Specify --help for usage, or press the help button on the CMake GUI.
#26 136.9   Traceback (most recent call last):
#26 136.9     File "<string>", line 1, in <module>
#26 136.9     File "/tmp/pip-req-build-ydsm2i5w/setup.py", line 171, in <module>
#26 136.9       setup(
#26 136.9     File "/usr/lib/python3/dist-packages/setuptools/__init__.py", line 144, in setup
#26 136.9       return distutils.core.setup(**attrs)
#26 136.9     File "/usr/lib/python3.8/distutils/core.py", line 148, in setup
#26 136.9       dist.run_commands()
#26 136.9     File "/usr/lib/python3.8/distutils/dist.py", line 966, in run_commands
#26 136.9       self.run_command(cmd)
#26 136.9     File "/usr/lib/python3.8/distutils/dist.py", line 985, in run_command
#26 136.9       cmd_obj.run()
#26 136.9     File "/usr/lib/python3/dist-packages/wheel/bdist_wheel.py", line 223, in run
#26 136.9       self.run_command('build')
#26 136.9     File "/usr/lib/python3.8/distutils/cmd.py", line 313, in run_command
#26 136.9       self.distribution.run_command(command)
#26 136.9     File "/usr/lib/python3.8/distutils/dist.py", line 985, in run_command
#26 136.9       cmd_obj.run()
#26 136.9     File "/usr/lib/python3.8/distutils/command/build.py", line 135, in run
#26 136.9       self.run_command(cmd_name)
#26 136.9     File "/usr/lib/python3.8/distutils/cmd.py", line 313, in run_command
#26 136.9       self.distribution.run_command(command)
#26 136.9     File "/usr/lib/python3.8/distutils/dist.py", line 985, in run_command
#26 136.9       cmd_obj.run()
#26 136.9     File "/tmp/pip-req-build-ydsm2i5w/setup.py", line 126, in run
#26 136.9       self.build_extension(ext)
#26 136.9     File "/tmp/pip-req-build-ydsm2i5w/setup.py", line 167, in build_extension
#26 136.9       subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
#26 136.9     File "/usr/lib/python3.8/subprocess.py", line 364, in check_call
#26 136.9       raise CalledProcessError(retcode, cmd)
#26 136.9   subprocess.CalledProcessError: Command '['cmake', '/tmp', '-DLLVM_ENABLE_WERROR=ON', '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=/tmp/pip-req-build-ydsm2i5w/build/lib.linux-x86_64-3.8/triton/_C', '-DTRITON_BUILD_TUTORIALS=OFF', '-DTRITON_BUILD_PYTHON_MODULE=ON', '-DPython3_EXECUTABLE:FILEPATH=/usr/bin/python3', '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', '-DPYTHON_INCLUDE_DIRS=/usr/include/python3.8', '-DPYBIND11_INCLUDE_DIR=/root/.triton/pybind11/pybind11-2.10.0/include', '-DLLVM_INCLUDE_DIRS=/root/.triton/llvm/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04/include', '-DLLVM_LIBRARY_DIR=/root/.triton/llvm/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04/lib', '-DCMAKE_BUILD_TYPE=TritonRelBuildWithAsserts']' returned non-zero exit status 1.
#26 137.0   ERROR: Failed building wheel for triton
#26 137.0   Running command /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-ydsm2i5w/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-ydsm2i5w/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' clean --all
#26 137.1   running clean
#26 137.1   removing 'build/temp.linux-x86_64-3.8' (and everything under it)
#26 137.1   removing 'build/lib.linux-x86_64-3.8' (and everything under it)
#26 137.1   'build/bdist.linux-x86_64' does not exist -- can't clean it
#26 137.1   'build/scripts-3.8' does not exist -- can't clean it
#26 137.1   removing 'build'
#26 137.1   Running command /usr/bin/python3 -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-wheel-zp2et7ds/lit/setup.py'"'"'; __file__='"'"'/tmp/pip-wheel-zp2et7ds/lit/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-7epb6eg8
#26 137.3   running bdist_wheel
#26 137.3   running build
#26 137.3   running build_py
#26 137.3   creating build
#26 137.3   creating build/lib
#26 137.3   creating build/lib/lit
#26 137.3   copying lit/TestTimes.py -> build/lib/lit
#26 137.3   copying lit/cl_arguments.py -> build/lib/lit
#26 137.3   copying lit/__init__.py -> build/lib/lit
#26 137.3   copying lit/TestRunner.py -> build/lib/lit
#26 137.3   copying lit/TestingConfig.py -> build/lib/lit
#26 137.3   copying lit/display.py -> build/lib/lit
#26 137.3   copying lit/LitConfig.py -> build/lib/lit
#26 137.3   copying lit/ShCommands.py -> build/lib/lit
#26 137.3   copying lit/main.py -> build/lib/lit
#26 137.3   copying lit/BooleanExpression.py -> build/lib/lit
#26 137.3   copying lit/discovery.py -> build/lib/lit
#26 137.3   copying lit/LitTestCase.py -> build/lib/lit
#26 137.3   copying lit/run.py -> build/lib/lit
#26 137.3   copying lit/util.py -> build/lib/lit
#26 137.3   copying lit/reports.py -> build/lib/lit
#26 137.3   copying lit/Test.py -> build/lib/lit
#26 137.3   copying lit/ProgressBar.py -> build/lib/lit
#26 137.3   copying lit/ShUtil.py -> build/lib/lit
#26 137.3   copying lit/worker.py -> build/lib/lit
#26 137.3   creating build/lib/lit/builtin_commands
#26 137.3   copying lit/builtin_commands/diff.py -> build/lib/lit/builtin_commands
#26 137.3   copying lit/builtin_commands/__init__.py -> build/lib/lit/builtin_commands
#26 137.3   copying lit/builtin_commands/cat.py -> build/lib/lit/builtin_commands
#26 137.3   creating build/lib/lit/formats
#26 137.3   copying lit/formats/base.py -> build/lib/lit/formats
#26 137.3   copying lit/formats/__init__.py -> build/lib/lit/formats
#26 137.3   copying lit/formats/shtest.py -> build/lib/lit/formats
#26 137.3   copying lit/formats/googletest.py -> build/lib/lit/formats
#26 137.3   creating build/lib/lit/llvm
#26 137.3   copying lit/llvm/config.py -> build/lib/lit/llvm
#26 137.3   copying lit/llvm/__init__.py -> build/lib/lit/llvm
#26 137.3   copying lit/llvm/subst.py -> build/lib/lit/llvm
#26 137.3   installing to build/bdist.linux-x86_64/wheel
#26 137.3   running install
#26 137.3   running install_lib
#26 137.3   creating build/bdist.linux-x86_64
#26 137.3   creating build/bdist.linux-x86_64/wheel
#26 137.3   creating build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/TestTimes.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   creating build/bdist.linux-x86_64/wheel/lit/builtin_commands
#26 137.3   copying build/lib/lit/builtin_commands/diff.py -> build/bdist.linux-x86_64/wheel/lit/builtin_commands
#26 137.3   copying build/lib/lit/builtin_commands/__init__.py -> build/bdist.linux-x86_64/wheel/lit/builtin_commands
#26 137.3   copying build/lib/lit/builtin_commands/cat.py -> build/bdist.linux-x86_64/wheel/lit/builtin_commands
#26 137.3   copying build/lib/lit/cl_arguments.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/__init__.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/TestRunner.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/TestingConfig.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/display.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/LitConfig.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   creating build/bdist.linux-x86_64/wheel/lit/formats
#26 137.3   copying build/lib/lit/formats/base.py -> build/bdist.linux-x86_64/wheel/lit/formats
#26 137.3   copying build/lib/lit/formats/__init__.py -> build/bdist.linux-x86_64/wheel/lit/formats
#26 137.3   copying build/lib/lit/formats/shtest.py -> build/bdist.linux-x86_64/wheel/lit/formats
#26 137.3   copying build/lib/lit/formats/googletest.py -> build/bdist.linux-x86_64/wheel/lit/formats
#26 137.3   copying build/lib/lit/ShCommands.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/main.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/BooleanExpression.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/discovery.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/LitTestCase.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/run.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   creating build/bdist.linux-x86_64/wheel/lit/llvm
#26 137.3   copying build/lib/lit/llvm/config.py -> build/bdist.linux-x86_64/wheel/lit/llvm
#26 137.3   copying build/lib/lit/llvm/__init__.py -> build/bdist.linux-x86_64/wheel/lit/llvm
#26 137.3   copying build/lib/lit/llvm/subst.py -> build/bdist.linux-x86_64/wheel/lit/llvm
#26 137.3   copying build/lib/lit/util.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/reports.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/Test.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/ProgressBar.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/ShUtil.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   copying build/lib/lit/worker.py -> build/bdist.linux-x86_64/wheel/lit
#26 137.3   running install_egg_info
#26 137.3   running egg_info
#26 137.3   writing lit.egg-info/PKG-INFO
#26 137.3   writing dependency_links to lit.egg-info/dependency_links.txt
#26 137.3   writing entry points to lit.egg-info/entry_points.txt
#26 137.3   writing top-level names to lit.egg-info/top_level.txt
#26 137.3   reading manifest file 'lit.egg-info/SOURCES.txt'
#26 137.3   reading manifest template 'MANIFEST.in'
#26 137.3   warning: no files found matching 'TODO'
#26 137.3   warning: no previously-included files matching '*pyc' found anywhere in distribution
#26 137.3   warning: no previously-included files matching '*~' found anywhere in distribution
#26 137.3   no previously-included directories found matching 'tests/Output'
#26 137.3   no previously-included directories found matching 'tests/*/Output'
#26 137.3   no previously-included directories found matching 'tests/*/*/Output'
#26 137.3   no previously-included directories found matching 'tests/*/*/*/Output'
#26 137.3   writing manifest file 'lit.egg-info/SOURCES.txt'
#26 137.3   Copying lit.egg-info to build/bdist.linux-x86_64/wheel/lit-15.0.7.egg-info
#26 137.3   running install_scripts
#26 137.4   adding license file "LICENSE.TXT" (matched pattern "LICEN[CS]E*")
#26 137.4   creating build/bdist.linux-x86_64/wheel/lit-15.0.7.dist-info/WHEEL
#26 137.4   creating '/tmp/pip-wheel-7epb6eg8/lit-15.0.7-py3-none-any.whl' and adding 'build/bdist.linux-x86_64/wheel' to it
#26 137.4   adding 'lit/BooleanExpression.py'
#26 137.4   adding 'lit/LitConfig.py'
#26 137.4   adding 'lit/LitTestCase.py'
#26 137.4   adding 'lit/ProgressBar.py'
#26 137.4   adding 'lit/ShCommands.py'
#26 137.4   adding 'lit/ShUtil.py'
#26 137.4   adding 'lit/Test.py'
#26 137.4   adding 'lit/TestRunner.py'
#26 137.4   adding 'lit/TestTimes.py'
#26 137.4   adding 'lit/TestingConfig.py'
#26 137.4   adding 'lit/__init__.py'
#26 137.4   adding 'lit/cl_arguments.py'
#26 137.4   adding 'lit/discovery.py'
#26 137.4   adding 'lit/display.py'
#26 137.4   adding 'lit/main.py'
#26 137.4   adding 'lit/reports.py'
#26 137.4   adding 'lit/run.py'
#26 137.4   adding 'lit/util.py'
#26 137.4   adding 'lit/worker.py'
#26 137.4   adding 'lit/builtin_commands/__init__.py'
#26 137.4   adding 'lit/builtin_commands/cat.py'
#26 137.4   adding 'lit/builtin_commands/diff.py'
#26 137.4   adding 'lit/formats/__init__.py'
#26 137.4   adding 'lit/formats/base.py'
#26 137.4   adding 'lit/formats/googletest.py'
#26 137.4   adding 'lit/formats/shtest.py'
#26 137.4   adding 'lit/llvm/__init__.py'
#26 137.4   adding 'lit/llvm/config.py'
#26 137.4   adding 'lit/llvm/subst.py'
#26 137.4   adding 'lit-15.0.7.dist-info/LICENSE.TXT'
#26 137.4   adding 'lit-15.0.7.dist-info/METADATA'
#26 137.4   adding 'lit-15.0.7.dist-info/WHEEL'
#26 137.4   adding 'lit-15.0.7.dist-info/entry_points.txt'
#26 137.4   adding 'lit-15.0.7.dist-info/top_level.txt'
#26 137.4   adding 'lit-15.0.7.dist-info/RECORD'
#26 137.4   removing build/bdist.linux-x86_64/wheel
#26 138.1 ERROR: Failed to build one or more wheels
#26 ERROR: executor failed running [/bin/sh -c git clone https://github.com/openai/triton.git     && cd triton/python     && git checkout d3e753b5c00bbae855b283adf3d3a5d6d1547830     && python3 -m pip install cmake scikit-build     && python3 -m pip wheel --wheel-dir /tmp/dist . --verbose]: exit code: 1
------
 > [dep_triton 4/4] RUN git clone https://github.com/openai/triton.git     && cd triton/python     && git checkout d3e753b5c00bbae855b283adf3d3a5d6d1547830     && python3 -m pip install cmake scikit-build     && python3 -m pip wheel --wheel-dir /tmp/dist . --verbose:
#26 137.4   adding 'lit/llvm/config.py'
#26 137.4   adding 'lit/llvm/subst.py'
#26 137.4   adding 'lit-15.0.7.dist-info/LICENSE.TXT'
#26 137.4   adding 'lit-15.0.7.dist-info/METADATA'
#26 137.4   adding 'lit-15.0.7.dist-info/WHEEL'
#26 137.4   adding 'lit-15.0.7.dist-info/entry_points.txt'
#26 137.4   adding 'lit-15.0.7.dist-info/top_level.txt'
#26 137.4   adding 'lit-15.0.7.dist-info/RECORD'
#26 137.4   removing build/bdist.linux-x86_64/wheel
#26 138.1 ERROR: Failed to build one or more wheels
------
failed to solve: executor failed running [/bin/sh -c git clone https://github.com/openai/triton.git     && cd triton/python     && git checkout d3e753b5c00bbae855b283adf3d3a5d6d1547830     && python3 -m pip install cmake scikit-build     && python3 -m pip wheel --wheel-dir /tmp/dist . --verbose]: exit code: 1

I know I can build Triton with this Dockerfile:

FROM nvidia/cuda:12.0.0-devel-ubuntu22.04

ARG DEBIAN_FRONTEND=noninteractive
RUN apt update && apt install -y git software-properties-common
RUN apt install -y python3-pip
# RUN add-apt-repository ppa:deadsnakes/ppa && apt update
# RUN apt install -y python3.8
RUN git clone https://github.com/openai/triton.git \
    && cd triton/python \
    && rm -rf ~/.triton \
    && python3 -m pip install cmake \
    && python3 -m pip wheel --wheel-dir /tmp/dist . --verbose

# RUN python3 -c "from triton.language import full"

so maybe Triton just does not build on old versions of Ubuntu for some reason.
One question I have is, does Triton require CUDA to build? I'm currently using a hack in my Dockerfile where I build Triton with Ubuntu 22 + CUDA 12 as a wheel, and then I install it in a different docker image which is running Ubuntu 18 + CUDA 11. Not sure if this will cause issues due to mismatched CUDA versions.

@ptillet
Collaborator

ptillet commented Jan 26, 2023

Yeah, Ubuntu 18.04 includes GCC 7.5.0 by default, which doesn't fully support C++17: its standard library lacks the <filesystem> header that the build includes. You can still use Ubuntu 18.04, but you have to install a more recent version of gcc/clang. As for the other errors, we'll look into them when we have some time; we're still firefighting some issues with the new backend.
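
A quick pre-flight check can catch this before a long build fails. The sketch below asks `gcc -dumpversion` for the compiler version and compares the major number against 9, the version reported to work later in this thread. The helper names are made up for illustration, and the threshold is an assumption drawn from this thread rather than a documented Triton requirement.

```python
import shutil
import subprocess
from typing import Optional

MIN_GCC_MAJOR = 9  # assumption: GCC 9 is the version reported to work in this thread


def parse_major(version_text: str) -> int:
    """Extract the major version from strings such as '7.5.0' or '9'."""
    return int(version_text.strip().split(".")[0])


def gcc_major(compiler: str = "gcc") -> Optional[int]:
    """Return the compiler's major version, or None if it cannot be determined."""
    if shutil.which(compiler) is None:
        return None
    result = subprocess.run(
        [compiler, "-dumpversion"], capture_output=True, text=True, check=False
    )
    try:
        return parse_major(result.stdout)
    except ValueError:
        return None


major = gcc_major()
if major is None:
    print("gcc not found or version not recognized")
elif major < MIN_GCC_MAJOR:
    print(f"gcc {major} is likely too old; install gcc/g++ {MIN_GCC_MAJOR} or newer")
else:
    print(f"gcc {major} should be new enough")
```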

@ptillet
Collaborator

ptillet commented Jan 26, 2023

And FYI wheel built on Ubuntu 22 + CUDA 12 should work everywhere. Triton isn't tied to any CUDA version.

@conceptofmind
Author

I will update my gcc/clang compiler later and let you know if the issue is resolved.

@conceptofmind
Author

@ptillet It works now with gcc/g++ 9. I was able to install from source and run the Triton version of Flash Attention.

@StePoli-00

StePoli-00 commented Mar 11, 2024

I solved the issue by uninstalling Triton, installing it again, and then uninstalling it once more.
