Skip to content
Snippets Groups Projects
Commit 1baa2780 authored by Jayesh's avatar Jayesh
Browse files

cnn model design

parent 19c5fbba
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import keras_tuner as kt
import numpy as np
import tensorflow as tf
```
%% Cell type:code id: tags:
``` python
from constants import *
```
%% Cell type:code id: tags:
``` python
class StackOverflowCNN:
"""
A class representing a CNN model for Stack Overflow classification.
"""
def __init__(self, vocab_length, embedding_matrix, filters=64, hidden_layer_size=0,
num_convolutions=3, kernel_size=3):
"""
Initializes the StackOverflowCNN model.
Args:
vocab_length (int): The length of the vocabulary.
embedding_matrix (numpy.ndarray): The embedding matrix.
filters (int, optional): The number of filters in the convolutional layers. Defaults to 64.
hidden_layer_size (int, optional): The size of the hidden layer. Defaults to 0.
num_convolutions (int, optional): The number of convolutional layers. Defaults to 3.
kernel_size (int, optional): The size of the convolutional kernel. Defaults to 3.
"""
input_layer = tf.keras.layers.Input(
shape=(MAX_SEQUENCE_LENGTH,),
batch_size=32
)
embedding = tf.keras.layers.Embedding(vocab_length,
EMBEDDING_DIM,
input_length=MAX_SEQUENCE_LENGTH,
embeddings_initializer=tf.keras.initializers.Constant(embedding_matrix),
trainable=True, mask_zero=True)(input_layer)
conv_layers = [
tf.keras.layers.Conv1D(filters=filters,
kernel_size=kernel_size)(embedding)
for _ in range(num_convolutions)
]
pooling_layers = [
tf.keras.layers.GlobalMaxPooling1D()(conv)
for conv in conv_layers
]
if len(pooling_layers) == 1:
hidden = pooling_layers[0]
else:
hidden = tf.keras.layers.concatenate(pooling_layers, axis=1)
if hidden_layer_size > 0:
hidden = tf.keras.layers.Dense(hidden_layer_size, activation='relu')(hidden)
outputs = tf.keras.layers.Dense(NUM_CLASSES)(hidden)
self.model = tf.keras.Model(inputs=[embedding], outputs=outputs)
def get_model(self):
"""
Returns the CNN model.
Returns:
tf.keras.Model: The CNN model.
"""
return self.model
```
%% Cell type:code id: tags:
``` python
```
MAX_SEQUENCE_LENGTH = 200
EMBEDDING_DIM = 100
NUM_CLASSES = 3
\ No newline at end of file
absl-py==2.1.0
annotated-types==0.6.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.0
astunparse==1.6.3
async-lru==2.0.4
attrs==23.2.0
Babel==2.15.0
backcall==0.2.0
backports.tarfile==1.1.1
beautifulsoup4==4.12.3
bleach==6.1.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
colorama==0.4.6
comm==0.1.4
contourpy==1.1.1
cycler==0.12.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
dnspython==2.4.2
email-validator==2.0.0.post2
et-xmlfile==1.1.0
exceptiongroup==1.1.3
executing==2.0.0
fastapi==0.103.2
fastjsonschema==2.19.1
flatbuffers==24.3.25
fonttools==4.43.1
fqdn==1.5.1
funcy==2.0
gast==0.5.4
gensim==4.3.2
google-pasta==0.2.0
grpcio==1.64.1
h11==0.14.0
h5py==3.11.0
httpcore==0.18.0
httptools==0.6.1
httpx==0.25.0
idna==3.4
importlib_metadata==7.1.0
ipykernel==6.25.2
ipython==8.16.1
ipywidgets==8.1.3
isoduration==20.11.0
itsdangerous==2.1.2
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.1
jedi==0.19.1
Jinja2==3.1.2
jira==3.8.0
joblib==1.4.2
json5==0.9.25
jsonpointer==2.4
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.4.0
jupyter_core==5.4.0
jupyter_server==2.14.1
jupyter_server_terminals==0.5.3
jupyterlab==4.2.1
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.2
jupyterlab_widgets==3.0.11
keras==3.3.3
keras-tuner==1.4.7
keyring==25.2.1
kiwisolver==1.4.5
kt-legacy==1.0.5
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.0
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.3.2
more-itertools==10.2.0
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.5.8
nltk==3.8.1
notebook==7.2.1
notebook_shim==0.2.4
numexpr==2.10.0
numpy==1.26.1
oauthlib==3.2.2
openpyxl==3.1.2
opt-einsum==3.3.0
optree==0.11.0
orjson==3.9.9
overrides==7.7.0
packaging==23.2
pandas==2.1.1
pandocfilters==1.5.1
parso==0.8.3
patsy==0.5.6
pickleshare==0.7.5
Pillow==10.1.0
platformdirs==3.11.0
plotly==5.22.0
prometheus_client==0.20.0
prompt-toolkit==3.0.39
protobuf==4.25.3
psutil==5.9.6
pure-eval==0.2.2
pycparser==2.22
pydantic==2.4.2
pydantic-extra-types==2.1.0
pydantic-settings==2.0.3
pydantic_core==2.10.1
Pygments==2.16.1
pyLDAvis==3.4.1
pyparsing==3.1.1
python-dateutil==2.8.2
python-dotenv==1.0.0
python-json-logger==2.0.7
python-multipart==0.0.6
pytz==2023.3.post1
pywin32==306
pywin32-ctypes==0.2.2
pywinpty==2.0.13
PyYAML==6.0.1
pyzmq==25.1.1
qtconsole==5.5.2
QtPy==2.4.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.2
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
scikit-learn==1.5.0
scikit-posthocs==0.9.0
scipy==1.12.0
seaborn==0.13.2
Send2Trash==1.8.3
six==1.16.0
smart-open==7.0.4
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.27.0
statsmodels==0.14.2
tenacity==8.3.0
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-intel==2.16.1
tensorflow-io-gcs-filesystem==0.31.0
termcolor==2.4.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.3.0
tomli==2.0.1
tornado==6.3.3
tqdm==4.66.4
traitlets==5.11.2
types-python-dateutil==2.9.0.20240316
typing_extensions==4.8.0
tzdata==2023.3
ujson==5.8.0
uri-template==1.3.0
urllib3==2.2.1
uvicorn==0.23.2
watchfiles==0.21.0
wcwidth==0.2.8
webcolors==24.6.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==11.0.3
Werkzeug==3.0.3
widgetsnbextension==4.0.11
wrapt==1.16.0
zipp==3.19.0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment