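【Posted】: 2020-08-17 16:52:33
【Question】:
I want to use a promising neural network that I found on Towards Data Science for a case study.
The shapes of my data are:
X_train: (1200, 18, 15)
y_train: (1200, 18, 1)
Here is the NN, which contains layers such as GRU, Flatten and Dense.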
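# Imports were omitted in the original post; tf.keras is assumed because the traceback points to tensorflow\python\keras.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (AveragePooling1D, Bidirectional, Conv1D,
                                     Dense, Dropout, Flatten, GRU)

def twds_model(layer1=32, layer2=32, layer3=16, dropout_rate=0.5, optimizer='Adam',
               learning_rate=0.001, activation='relu', loss='mse'):
    # Note: learning_rate is accepted but never used; the optimizer string keeps its default rate.
    model = Sequential()
    model.add(Bidirectional(GRU(layer1, return_sequences=True),
                            input_shape=(X_train.shape[1], X_train.shape[2])))
    model.add(AveragePooling1D(2))
    model.add(Conv1D(layer2, 3, activation=activation, padding='same',
                     name='extractor'))
    model.add(Flatten())
    model.add(Dense(layer3, activation=activation))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1))  # one output value per sample
    model.compile(optimizer=optimizer, loss=loss)
    return model

# This rebinds the name twds_model from the factory function to the model instance.
twds_model = twds_model()
print(twds_model.summary())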
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bidirectional_4 (Bidirection (None, 18, 64) 9216
_________________________________________________________________
average_pooling1d_1 (Average (None, 9, 64) 0
_________________________________________________________________
extractor (Conv1D) (None, 9, 32) 6176
_________________________________________________________________
flatten_1 (Flatten) (None, 288) 0
_________________________________________________________________
dense_3 (Dense) (None, 16) 4624
_________________________________________________________________
dropout_4 (Dropout) (None, 16) 0
_________________________________________________________________
dense_4 (Dense) (None, 1) 17
=================================================================
Total params: 20,033
Trainable params: 20,033
Non-trainable params: 0
_________________________________________________________________
None
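The shapes and parameter counts in this summary can be checked by hand. The following is a minimal pure-Python sketch (no Keras required). The helper names are hypothetical, and the GRU formula assumes the classic variant with a single bias vector per gate, which is what the reported 9216 is consistent with. It also makes visible that the last layer emits one value per sample, (None, 1), while y_train carries 18 values per sample.

```python
# Hand-computed output shapes and parameter counts for the layer stack above.
# Helper names are illustrative only; they are not part of Keras.

def gru_params(units, input_dim):
    # 3 gates, each with an input kernel, a recurrent kernel and one bias vector
    return 3 * (units * (input_dim + units) + units)

def conv1d_params(filters, kernel_size, in_channels):
    return kernel_size * in_channels * filters + filters

def dense_params(units, in_features):
    return in_features * units + units

timesteps, features = 18, 15
layer1, layer2, layer3 = 32, 32, 16

rows = []
# Bidirectional wraps two GRUs and concatenates their outputs
rows.append(("bidirectional", (timesteps, 2 * layer1), 2 * gru_params(layer1, features)))
pooled = timesteps // 2  # AveragePooling1D(2) halves the time axis
rows.append(("average_pooling1d", (pooled, 2 * layer1), 0))
rows.append(("extractor", (pooled, layer2), conv1d_params(layer2, 3, 2 * layer1)))
flat = pooled * layer2   # Flatten merges the time and channel axes
rows.append(("flatten", (flat,), 0))
rows.append(("dense", (layer3,), dense_params(layer3, flat)))
rows.append(("dropout", (layer3,), 0))
rows.append(("dense_out", (1,), dense_params(1, layer3)))

for name, shape, params in rows:
    print(f"{name:20s} {str((None,) + shape):18s} {params}")

total_params = sum(params for _, _, params in rows)
model_output_shape = rows[-1][1]   # per-sample prediction shape
target_shape = (18, 1)             # per-sample shape of y_train
print("total:", total_params)
print("prediction", model_output_shape, "vs target", target_shape)
```

Because Flatten removes the time axis before the Dense layers, the prediction has no 18-step dimension left to compare against the target.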
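Unfortunately, I am stuck in a seemingly contradictory error trap in which the input and output shapes do not match. Here is the error for the case above.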
InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
[[{{node loss_2/dense_4_loss/sub}}]]
[[{{node loss_2/mul}}]]
Train on 10420 samples, validate on 1697 samples
Epoch 1/8
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-30-3f5256ff03ec> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
878 initial_epoch=initial_epoch,
879 steps_per_epoch=steps_per_epoch,
--> 880 validation_steps=validation_steps)
881
882 def evaluate(self,
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, mode, validation_in_fit, **kwargs)
327
328 # Get outputs.
--> 329 batch_outs = f(ins_batch)
330 if not isinstance(batch_outs, list):
331 batch_outs = [batch_outs]
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
3074
3075 fetched = self._callable_fn(*array_vals,
-> 3076 run_metadata=self.run_metadata)
3077 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3078 return nest.pack_sequence_as(self._outputs_structure,
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
1437 ret = tf_session.TF_SessionRunCallable(
1438 self._session._session, self._handle, args, status,
-> 1439 run_metadata_ptr)
1440 if run_metadata:
1441 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
526 None, None,
527 compat.as_text(c_api.TF_Message(self.status.status)),
--> 528 c_api.TF_GetCode(self.status.status))
529 # Delete the underlying status object from memory otherwise it stays alive
530 # as there is a reference to status from this from the traceback due to
InvalidArgumentError: Incompatible shapes: [144,1] vs. [144,18,1]
[[{{node loss_2/dense_4_loss/sub}}]]
[[{{node loss_2/mul}}]]
For completeness, here is the (expected) error when y_train is reshaped to (1200*18, 1):
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-47-2a6d0761b794> in <module>
----> 1 Test_tdws=twds_model.fit(X_train, y_train_flat, epochs=8, batch_size=144, verbose=2, validation_split=(0.14), shuffle=False) #callbacks=[tensorboard])
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, max_queue_size, workers, use_multiprocessing, **kwargs)
774 steps=steps_per_epoch,
775 validation_split=validation_split,
--> 776 shuffle=shuffle)
777
778 # Prepare validation data.
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle)
2434 # Check that all arrays have the same length.
2435 if not self._distribution_strategy:
-> 2436 training_utils.check_array_lengths(x, y, sample_weights)
2437 if self._is_graph_network and not self.run_eagerly:
2438 # Additional checks to avoid users mistakenly using improper loss fns.
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_utils.py in check_array_lengths(inputs, targets, weights)
454 'the same number of samples as target arrays. '
455 'Found ' + str(list(set_x)[0]) + ' input samples '
--> 456 'and ' + str(list(set_y)[0]) + ' target samples.')
457 if len(set_w) > 1:
458 raise ValueError('All sample_weight arrays should have '
ValueError: Input arrays should have the same number of samples as target arrays. Found 12117 input samples and 218106 target samples
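The numbers in both error messages are consistent with one another, and they suggest that X_train actually holds 12117 sequences rather than 1200. A small arithmetic sketch follows. The totals are taken from the messages above; the rounding rule for the validation split, int(n * (1 - validation_split)), is an assumption about how this Keras version splits the data.

```python
# Reproduce the sample counts reported in the two error messages.
n_samples = 10420 + 1697          # "Train on 10420 samples, validate on 1697 samples"
timesteps = 18
validation_split = 0.14

# Assumed rounding rule: the first int(n * (1 - split)) samples are used for training
n_train = int(n_samples * (1 - validation_split))
n_val = n_samples - n_train

# Flattening y from (n, 18, 1) to (n * 18, 1) multiplies the sample count by 18
n_flat_targets = n_samples * timesteps

print(n_samples, n_train, n_val, n_flat_targets)
```

Flattening the targets therefore cannot work: Keras compares the first axis of X and y, and 12117 inputs can never line up with 218106 targets.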
The versions used are:
Package Version
---------------------- --------------------
- nsorflow-gpu
-ensorflow-gpu 1.13.1
-rotobuf 3.11.3
-umpy 1.18.1
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.2.5
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.19.0
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.0.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 1.13.1
tensorboard-plugin-wit 1.6.0.post3
tensorflow-estimator 1.13.0
tensorflow-gpu 1.13.1
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
tf-estimator-nightly 1.14.0.dev2019060501
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
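Many thanks for any hints that lead to running code ;-)!
EDIT EDIT EDIT
After updating tensorflow and keras to the latest versions, I get the following error. The error persists even though tensorflow, CUDA 10.1 and cudnn 8.0.2 were completely removed and reinstalled. It is produced both by my original code and by Fallen Apart's example code.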
UnknownError: Fail to find the dnn implementation.
[[{{node CudnnRNN}}]]
[[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]
Function call stack:
train_function -> train_function -> train_function
None
Epoch 1/4
---------------------------------------------------------------------------
UnknownError Traceback (most recent call last)
<ipython-input-1-64eb8afffe02> in <module>
27 print(twds_model.summary())
28
---> 29 twds_model.fit(X_train, y_train, epochs=4)
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
106 def _method_wrapper(self, *args, **kwargs):
107 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
--> 108 return method(self, *args, **kwargs)
109
110 # Running inside `run_distribute_coordinator` already.
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1096 batch_size=batch_size):
1097 callbacks.on_train_batch_begin(step)
-> 1098 tmp_logs = train_function(iterator)
1099 if data_handler.should_sync:
1100 context.async_wait()
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
838 # Lifting succeeded, so variables are initialized and we can run the
839 # stateless function.
--> 840 return self._stateless_fn(*args, **kwds)
841 else:
842 canon_args, canon_kwds = \
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2827 with self._lock:
2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2830
2831 @property
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
1846 resource_variable_ops.BaseResourceVariable))],
1847 captured_inputs=self.captured_inputs,
-> 1848 cancellation_manager=cancellation_manager)
1849
1850 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1922 # No tape is watching; skip to running the function.
1923 return self._build_call_outputs(self._inference_function.call(
-> 1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
1926 args,
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
548 inputs=args,
549 attrs=attrs,
--> 550 ctx=ctx)
551 else:
552 outputs = execute.execute_with_cancellation(
~\Anaconda3\envs\Tensorflow\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
UnknownError: Fail to find the dnn implementation.
[[{{node CudnnRNN}}]]
[[sequential/bidirectional/forward_gru/PartitionedCall]] [Op:__inference_train_function_5731]
Function call stack:
train_function -> train_function -> train_function
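One thing that may be worth checking for this error is whether the installed cuDNN matches what the TensorFlow build was tested against. The sketch below is a hypothetical helper, not a TensorFlow API. The expected versions in the table (CUDA 10.1 and cuDNN 7.6 for tensorflow-gpu 2.3.0) are an assumption taken from TensorFlow's published tested-build configurations and should be verified there before relying on them.

```python
# Compare an installation against the GPU libraries a TensorFlow build was tested with.
# The table entry is an assumption; verify it against the official tested-build table.
TESTED_BUILDS = {
    "2.3.0": {"cuda": "10.1", "cudnn": "7.6"},
}

def major_minor(version):
    # "8.0.2" -> ("8", "0")
    return tuple(version.split(".")[:2])

def check_install(tf_version, cuda, cudnn):
    """Return a list of human-readable mismatches (empty if everything matches)."""
    expected = TESTED_BUILDS[tf_version]
    problems = []
    if major_minor(cuda) != major_minor(expected["cuda"]):
        problems.append(f"CUDA {cuda} installed, {expected['cuda']} expected")
    if major_minor(cudnn) != major_minor(expected["cudnn"]):
        problems.append(f"cuDNN {cudnn} installed, {expected['cudnn']} expected")
    return problems

# Versions as stated in the edit above
issues = check_install("2.3.0", cuda="10.1", cudnn="8.0.2")
for issue in issues:
    print(issue)
```

If the table is right, only the cuDNN version stands out, which would point at the library install rather than at the model code.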
The corresponding version list:
Package Version
------------------------ ---------------
- nsorflow-gpu
-ensorflow-gpu 2.3.0
-rotobuf 3.11.3
absl-py 0.9.0
antlr4-python3-runtime 4.8
asn1crypto 1.3.0
astor 0.7.1
astropy 3.2.1
astunparse 1.6.3
attrs 19.3.0
audioread 2.1.8
autopep8 1.5.3
backcall 0.1.0
beautifulsoup4 4.9.0
bezier 0.8.0
bkcharts 0.2
bleach 3.1.4
blis 0.2.4
bokeh 1.1.0
boto3 1.9.253
botocore 1.12.253
Bottleneck 1.3.2
cachetools 4.1.0
certifi 2020.4.5.1
cffi 1.14.0
chardet 3.0.4
click 6.7
cloudpickle 0.5.3
cmdstanpy 0.4.0
color 0.1
colorama 0.4.3
colorcet 0.9.1
convertdate 2.2.1
copulas 0.2.5
cryptography 2.8
ctgan 0.2.1
cycler 0.10.0
cymem 2.0.2
Cython 0.29.17
dash 0.26.0
dash-core-components 0.27.2
dash-html-components 0.11.0
dash-renderer 0.13.2
dask 0.18.1
dataclasses 0.6
datashader 0.7.0
datashape 0.5.2
datawig 0.1.10
deap 1.3.0
decorator 4.4.2
defusedxml 0.6.0
deltapy 0.1.1
dill 0.2.9
distributed 1.22.1
docutils 0.14
entrypoints 0.3
ephem 3.7.7.1
et-xmlfile 1.0.1
exrex 0.10.5
Faker 4.0.3
fastai 1.0.60
fastprogress 0.2.2
fbprophet 0.6
fire 0.3.1
Flask 1.0.2
Flask-Compress 1.4.0
future 0.17.1
gast 0.3.3
geojson 2.4.1
geomet 0.2.0.post2
google-auth 1.14.0
google-auth-oauthlib 0.4.1
google-pasta 0.2.0
gplearn 0.4.1
graphviz 0.13.2
grpcio 1.29.0
h5py 2.10.0
HeapDict 1.0.0
holidays 0.10.2
holoviews 1.12.1
html2text 2018.1.9
hyperas 0.4.1
hyperopt 0.1.2
idna 2.6
imageio 2.5.0
imbalanced-learn 0.3.3
imblearn 0.0
importlib-metadata 1.5.0
impyute 0.0.8
ipykernel 5.1.4
ipython 7.13.0
ipython-genutils 0.2.0
ipywidgets 7.5.1
itsdangerous 0.24
jdcal 1.4
jedi 0.16.0
Jinja2 2.11.1
jmespath 0.9.5
joblib 0.13.2
jsonschema 3.2.0
jupyter 1.0.0
jupyter-client 6.1.2
jupyter-console 6.0.0
jupyter-core 4.6.3
Keras 2.4.3
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
keras-rectified-adam 0.17.0
kiwisolver 1.2.0
korean-lunar-calendar 0.2.1
librosa 0.7.2
llvmlite 0.32.1
lml 0.0.1
locket 0.2.0
LunarCalendar 0.0.9
Markdown 2.6.11
MarkupSafe 1.1.1
matplotlib 3.2.1
missingpy 0.2.0
mistune 0.8.4
mkl-fft 1.0.15
mkl-random 1.1.0
mkl-service 2.3.0
mock 4.0.2
msgpack 0.5.6
multipledispatch 0.6.0
murmurhash 1.0.2
mxnet 1.4.1
nb-conda 2.2.1
nb-conda-kernels 2.2.3
nbconvert 5.6.1
nbformat 5.0.4
nbstripout 0.3.7
networkx 2.1
notebook 6.0.3
numba 0.49.1
numexpr 2.7.1
numpy 1.18.5
oauthlib 3.1.0
olefile 0.46
opencv-python 4.2.0.34
openpyxl 2.5.5
opt-einsum 3.2.1
packaging 20.3
pandas 1.0.3
pandasvault 0.0.3
pandocfilters 1.4.2
param 1.9.0
parso 0.6.2
partd 0.3.8
patsy 0.5.1
pbr 5.1.3
pickleshare 0.7.5
Pillow 7.0.0
pip 20.2.2
plac 0.9.6
plotly 4.7.1
plotly-express 0.4.1
preshed 2.0.1
prometheus-client 0.7.1
prompt-toolkit 3.0.4
protobuf 3.11.3
psutil 5.4.7
py 1.8.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.6.0
pycparser 2.20
pyct 0.4.5
pyensae 1.3.839
pyexcel 0.5.8
pyexcel-io 0.5.7
Pygments 2.6.1
pykalman 0.9.5
PyMeeus 0.3.7
pymongo 3.8.0
pyOpenSSL 19.1.0
pyparsing 2.4.7
pypi 2.1
pyquickhelper 1.9.3418
pyrsistent 0.16.0
PySocks 1.7.1
pystan 2.19.1.1
python-dateutil 2.8.1
pytz 2019.3
pyviz-comms 0.7.2
PyWavelets 0.5.2
pywin32 227
pywinpty 0.5.7
PyYAML 5.3.1
pyzmq 18.1.1
qtconsole 4.4.4
rdt 0.2.1
RegscorePy 1.1
requests 2.23.0
requests-oauthlib 1.3.0
resampy 0.2.2
retrying 1.3.3
rsa 4.0
s3transfer 0.2.1
scikit-image 0.15.0
scikit-learn 0.23.2
scipy 1.4.1
sdv 0.3.2
seaborn 0.9.0
seasonal 0.3.1
Send2Trash 1.5.0
sentinelsat 0.12.2
setuptools 46.3.0
setuptools-git 1.2
six 1.14.0
sklearn 0.0
sortedcontainers 2.0.4
SoundFile 0.10.3.post1
soupsieve 2.0
spacy 2.1.8
srsly 0.1.0
statsmodels 0.9.0
stopit 1.1.2
sugartensor 1.0.0.2
ta 0.5.25
tb-nightly 1.14.0a20190603
tblib 1.3.2
tensorboard 2.3.0
tensorboard-plugin-wit 1.7.0
tensorflow-gpu 2.3.0
tensorflow-gpu-estimator 2.3.0
termcolor 1.1.0
terminado 0.8.3
testpath 0.4.4
text-unidecode 1.3
texttable 1.4.0
Theano 1.0.4
thinc 7.0.8
threadpoolctl 2.1.0
toml 0.10.1
toolz 0.10.0
torch 1.4.0
torchvision 0.5.0
tornado 6.0.4
TPOT 0.10.2
tqdm 4.45.0
traitlets 4.3.3
transforms3d 0.3.1
tsaug 0.2.1
typeguard 2.7.1
typing 3.6.6
update-checker 0.16
urllib3 1.22
utm 0.4.2
wasabi 0.2.2
wcwidth 0.1.9
webencodings 0.5.1
Werkzeug 1.0.1
wheel 0.34.2
widgetsnbextension 3.5.1
win-inet-pton 1.1.0
wincertstore 0.2
wrapt 1.11.2
xarray 0.10.8
xlrd 1.1.0
yahoo-historical 0.3.2
zict 0.1.3
zipp 2.2.0
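【Discussion】:
-
What shape is your data? I checked with X_train.shape = (1200, 18, 15), y_train.shape = (1200, 18, 1) and twds_model.fit(X_train, y_train), and it worked fine.
-
@Fallen Apart: They are exactly those shapes. Which versions of keras, tensorflow, etc. do you have? Maybe that is the problem?
-
I have 2.2.0, but I don't think the version should be the problem. Make sure once more that you really have those shapes.
-
The shapes are correct (see below). There must be some version difference. However, the same error is raised in 2.3.0 as well.
Tags: tensorflow machine-learning multidimensional-array keras lstm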